github.com/SigNoz/golang-migrate/v4@v4.0.0-20231005133642-7493dbaf5f5b/internal/cli/main.go (about)

     1  package cli
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"os"
     7  	"os/signal"
     8  	"strconv"
     9  	"strings"
    10  	"syscall"
    11  	"time"
    12  
    13  	"github.com/golang-migrate/migrate/v4"
    14  	"github.com/golang-migrate/migrate/v4/database"
    15  	"github.com/golang-migrate/migrate/v4/source"
    16  )
    17  
    18  const (
    19  	defaultTimeFormat = "20060102150405"
    20  	defaultTimezone   = "UTC"
    21  	createUsage       = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME
    22  	   Create a set of timestamped up/down migrations titled NAME, in directory D with extension E.
    23  	   Use -seq option to generate sequential up/down migrations with N digits.
    24  	   Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error.
    25             Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC).
    26  `
    27  	gotoUsage = `goto V       Migrate to version V`
    28  	upUsage   = `up [N]       Apply all or N up migrations`
    29  	downUsage = `down [N] [-all]    Apply all or N down migrations
    30  	Use -all to apply all down migrations`
    31  	dropUsage = `drop [-f]    Drop everything inside database
    32  	Use -f to bypass confirmation`
    33  	forceUsage = `force V      Set version V but don't run migration (ignores dirty state)`
    34  )
    35  
    36  func handleSubCmdHelp(help bool, usage string, flagSet *flag.FlagSet) {
    37  	if help {
    38  		fmt.Fprintln(os.Stderr, usage)
    39  		flagSet.PrintDefaults()
    40  		os.Exit(0)
    41  	}
    42  }
    43  
    44  func newFlagSetWithHelp(name string) (*flag.FlagSet, *bool) {
    45  	flagSet := flag.NewFlagSet(name, flag.ExitOnError)
    46  	helpPtr := flagSet.Bool("help", false, "Print help information")
    47  	return flagSet, helpPtr
    48  }
    49  
// log is the package-level logger shared by every sub-command. Main sets its
// verbose field from the -verbose flag and installs it as the migrate
// instance's logger (migrater.Log).
//
// NOTE(review): this package-scope name shadows the standard "log" package,
// so no file in this package can import "log" directly — confirm intended.
var log = &Log{}
    52  
    53  func printUsageAndExit() {
    54  	flag.Usage()
    55  
    56  	// If a command is not found we exit with a status 2 to match the behavior
    57  	// of flag.Parse() with flag.ExitOnError when parsing an invalid flag.
    58  	os.Exit(2)
    59  }
    60  
    61  // Main function of a cli application. It is public for backwards compatibility with `cli` package
    62  func Main(version string) {
    63  	helpPtr := flag.Bool("help", false, "")
    64  	versionPtr := flag.Bool("version", false, "")
    65  	verbosePtr := flag.Bool("verbose", false, "")
    66  	templatePtr := flag.Bool("template", false, "")
    67  	prefetchPtr := flag.Uint("prefetch", 10, "")
    68  	lockTimeoutPtr := flag.Uint("lock-timeout", 15, "")
    69  	pathPtr := flag.String("path", "", "")
    70  	databasePtr := flag.String("database", "", "")
    71  	sourcePtr := flag.String("source", "", "")
    72  
    73  	flag.Usage = func() {
    74  		fmt.Fprintf(os.Stderr,
    75  			`Usage: migrate OPTIONS COMMAND [arg...]
    76         migrate [ -version | -help ]
    77  
    78  Options:
    79    -source          Location of the migrations (driver://url)
    80    -path            Shorthand for -source=file://path
    81    -database        Run migrations against this database (driver://url)
    82    -prefetch N      Number of migrations to load in advance before executing (default 10)
    83    -lock-timeout N  Allow N seconds to acquire database lock (default 15)
    84    -template        Treat migration files as go text templates; making environment variables accessible in your
    85                     migration files. 
    86                       i.e. If you set the LOCAL_WAREHOUSE environment variable to MY_DB and have a migration file with the
    87                       following contents:
    88                         INSERT INTO {{.LOCAL_WAREHOUSE}}.INVENTORY.RECORDS ('foo') VALUES ('bar');
    89                       it will be transformed into the following before being executed:
    90                         INSERT INTO MY_DB.INVENTORY.RECORDS ('foo') VALUES ('bar');
    91                     
    92                       Note that enabling templating requires that the contents of the migration file be brought into memory 
    93                       in order to perform the transformation, and streaming the file directly from the source driver into
    94                       the database is not currently possible with the go templating implementation.
    95  
    96                       See https://pkg.go.dev/text/template for more information on supported template formats.
    97    -verbose         Print verbose logging
    98    -version         Print version
    99    -help            Print usage
   100  
   101  Commands:
   102    %s
   103    %s
   104    %s
   105    %s
   106    %s
   107    %s
   108    version      Print current migration version
   109  
   110  Source drivers: `+strings.Join(source.List(), ", ")+`
   111  Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoUsage, upUsage, downUsage, dropUsage, forceUsage)
   112  	}
   113  
   114  	flag.Parse()
   115  
   116  	// initialize logger
   117  	log.verbose = *verbosePtr
   118  
   119  	// show cli version
   120  	if *versionPtr {
   121  		fmt.Fprintln(os.Stderr, version)
   122  		os.Exit(0)
   123  	}
   124  
   125  	// show help
   126  	if *helpPtr {
   127  		flag.Usage()
   128  		os.Exit(0)
   129  	}
   130  
   131  	// translate -path into -source if given
   132  	if *sourcePtr == "" && *pathPtr != "" {
   133  		*sourcePtr = fmt.Sprintf("file://%v", *pathPtr)
   134  	}
   135  
   136  	// initialize migrate
   137  	// don't catch migraterErr here and let each command decide
   138  	// how it wants to handle the error
   139  	migrater, migraterErr := migrate.New(*sourcePtr, *databasePtr)
   140  	defer func() {
   141  		if migraterErr == nil {
   142  			if _, err := migrater.Close(); err != nil {
   143  				log.Println(err)
   144  			}
   145  		}
   146  	}()
   147  	if migraterErr == nil {
   148  		migrater.Log = log
   149  		migrater.PrefetchMigrations = *prefetchPtr
   150  		migrater.LockTimeout = time.Duration(int64(*lockTimeoutPtr)) * time.Second
   151  		if *templatePtr {
   152  			migrater.EnableTemplating = true
   153  		}
   154  
   155  		// handle Ctrl+c
   156  		signals := make(chan os.Signal, 1)
   157  		signal.Notify(signals, syscall.SIGINT)
   158  		go func() {
   159  			for range signals {
   160  				log.Println("Stopping after this running migration ...")
   161  				migrater.GracefulStop <- true
   162  				return
   163  			}
   164  		}()
   165  	}
   166  
   167  	startTime := time.Now()
   168  
   169  	if len(flag.Args()) < 1 {
   170  		printUsageAndExit()
   171  	}
   172  	args := flag.Args()[1:]
   173  
   174  	switch flag.Arg(0) {
   175  	case "create":
   176  
   177  		seq := false
   178  		seqDigits := 6
   179  
   180  		createFlagSet, help := newFlagSetWithHelp("create")
   181  		extPtr := createFlagSet.String("ext", "", "File extension")
   182  		dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)")
   183  		formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`)
   184  		timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`)
   185  		createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)")
   186  		createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)")
   187  
   188  		if err := createFlagSet.Parse(args); err != nil {
   189  			log.fatalErr(err)
   190  		}
   191  
   192  		handleSubCmdHelp(*help, createUsage, createFlagSet)
   193  
   194  		if createFlagSet.NArg() == 0 {
   195  			log.fatal("error: please specify name")
   196  		}
   197  		name := createFlagSet.Arg(0)
   198  
   199  		if *extPtr == "" {
   200  			log.fatal("error: -ext flag must be specified")
   201  		}
   202  
   203  		timezone, err := time.LoadLocation(*timezoneName)
   204  		if err != nil {
   205  			log.fatal(err)
   206  		}
   207  
   208  		if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil {
   209  			log.fatalErr(err)
   210  		}
   211  
   212  	case "goto":
   213  
   214  		gotoSet, helpPtr := newFlagSetWithHelp("goto")
   215  
   216  		if err := gotoSet.Parse(args); err != nil {
   217  			log.fatalErr(err)
   218  		}
   219  
   220  		handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet)
   221  
   222  		if migraterErr != nil {
   223  			log.fatalErr(migraterErr)
   224  		}
   225  
   226  		if gotoSet.NArg() == 0 {
   227  			log.fatal("error: please specify version argument V")
   228  		}
   229  
   230  		v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64)
   231  		if err != nil {
   232  			log.fatal("error: can't read version argument V")
   233  		}
   234  
   235  		if err := gotoCmd(migrater, uint(v)); err != nil {
   236  			log.fatalErr(err)
   237  		}
   238  
   239  		if log.verbose {
   240  			log.Println("Finished after", time.Since(startTime))
   241  		}
   242  
   243  	case "up":
   244  		upSet, helpPtr := newFlagSetWithHelp("up")
   245  
   246  		if err := upSet.Parse(args); err != nil {
   247  			log.fatalErr(err)
   248  		}
   249  
   250  		handleSubCmdHelp(*helpPtr, upUsage, upSet)
   251  
   252  		if migraterErr != nil {
   253  			log.fatalErr(migraterErr)
   254  		}
   255  
   256  		limit := -1
   257  		if upSet.NArg() > 0 {
   258  			n, err := strconv.ParseUint(upSet.Arg(0), 10, 64)
   259  			if err != nil {
   260  				log.fatal("error: can't read limit argument N")
   261  			}
   262  			limit = int(n)
   263  		}
   264  
   265  		if err := upCmd(migrater, limit); err != nil {
   266  			log.fatalErr(err)
   267  		}
   268  
   269  		if log.verbose {
   270  			log.Println("Finished after", time.Since(startTime))
   271  		}
   272  
   273  	case "down":
   274  		downFlagSet, helpPtr := newFlagSetWithHelp("down")
   275  		applyAll := downFlagSet.Bool("all", false, "Apply all down migrations")
   276  
   277  		if err := downFlagSet.Parse(args); err != nil {
   278  			log.fatalErr(err)
   279  		}
   280  
   281  		handleSubCmdHelp(*helpPtr, downUsage, downFlagSet)
   282  
   283  		if migraterErr != nil {
   284  			log.fatalErr(migraterErr)
   285  		}
   286  
   287  		downArgs := downFlagSet.Args()
   288  		num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs)
   289  		if err != nil {
   290  			log.fatalErr(err)
   291  		}
   292  		if needsConfirm {
   293  			log.Println("Are you sure you want to apply all down migrations? [y/N]")
   294  			var response string
   295  			fmt.Scanln(&response)
   296  			response = strings.ToLower(strings.TrimSpace(response))
   297  
   298  			if response == "y" {
   299  				log.Println("Applying all down migrations")
   300  			} else {
   301  				log.fatal("Not applying all down migrations")
   302  			}
   303  		}
   304  
   305  		if err := downCmd(migrater, num); err != nil {
   306  			log.fatalErr(err)
   307  		}
   308  
   309  		if log.verbose {
   310  			log.Println("Finished after", time.Since(startTime))
   311  		}
   312  
   313  	case "drop":
   314  		dropFlagSet, help := newFlagSetWithHelp("drop")
   315  		forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt")
   316  
   317  		if err := dropFlagSet.Parse(args); err != nil {
   318  			log.fatalErr(err)
   319  		}
   320  
   321  		handleSubCmdHelp(*help, dropUsage, dropFlagSet)
   322  
   323  		if !*forceDrop {
   324  			log.Println("Are you sure you want to drop the entire database schema? [y/N]")
   325  			var response string
   326  			fmt.Scanln(&response)
   327  			response = strings.ToLower(strings.TrimSpace(response))
   328  
   329  			if response == "y" {
   330  				log.Println("Dropping the entire database schema")
   331  			} else {
   332  				log.fatal("Aborted dropping the entire database schema")
   333  			}
   334  		}
   335  
   336  		if migraterErr != nil {
   337  			log.fatalErr(migraterErr)
   338  		}
   339  
   340  		if err := dropCmd(migrater); err != nil {
   341  			log.fatalErr(err)
   342  		}
   343  
   344  		if log.verbose {
   345  			log.Println("Finished after", time.Since(startTime))
   346  		}
   347  
   348  	case "force":
   349  		forceSet, helpPtr := newFlagSetWithHelp("force")
   350  
   351  		if err := forceSet.Parse(args); err != nil {
   352  			log.fatalErr(err)
   353  		}
   354  
   355  		handleSubCmdHelp(*helpPtr, forceUsage, forceSet)
   356  
   357  		if migraterErr != nil {
   358  			log.fatalErr(migraterErr)
   359  		}
   360  
   361  		if forceSet.NArg() == 0 {
   362  			log.fatal("error: please specify version argument V")
   363  		}
   364  
   365  		v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64)
   366  		if err != nil {
   367  			log.fatal("error: can't read version argument V")
   368  		}
   369  
   370  		if v < -1 {
   371  			log.fatal("error: argument V must be >= -1")
   372  		}
   373  
   374  		if err := forceCmd(migrater, int(v)); err != nil {
   375  			log.fatalErr(err)
   376  		}
   377  
   378  		if log.verbose {
   379  			log.Println("Finished after", time.Since(startTime))
   380  		}
   381  
   382  	case "version":
   383  		if migraterErr != nil {
   384  			log.fatalErr(migraterErr)
   385  		}
   386  
   387  		if err := versionCmd(migrater); err != nil {
   388  			log.fatalErr(err)
   389  		}
   390  
   391  	default:
   392  		printUsageAndExit()
   393  	}
   394  }