github.com/iasthc/atlas/cmd/atlas@v0.0.0-20230523071841-73246df3f88d/internal/cmdapi/migrate.go (about)

     1  // Copyright 2021-present The Atlas Authors. All rights reserved.
     2  // This source code is licensed under the Apache 2.0 license found
     3  // in the LICENSE file in the root directory of this source tree.
     4  
     5  package cmdapi
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"database/sql"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"io/fs"
    15  	"net/url"
    16  	"os"
    17  	"os/exec"
    18  	"path"
    19  	"path/filepath"
    20  	"strconv"
    21  	"strings"
    22  	"text/template"
    23  	"text/template/parse"
    24  	"time"
    25  
    26  	"github.com/iasthc/atlas/cmd/atlas/internal/cloudapi"
    27  	"github.com/iasthc/atlas/cmd/atlas/internal/cmdext"
    28  	"github.com/iasthc/atlas/cmd/atlas/internal/cmdlog"
    29  	"github.com/iasthc/atlas/cmd/atlas/internal/lint"
    30  	cmdmigrate "github.com/iasthc/atlas/cmd/atlas/internal/migrate"
    31  	"github.com/iasthc/atlas/cmd/atlas/internal/migrate/ent"
    32  	"github.com/iasthc/atlas/cmd/atlas/internal/migrate/ent/revision"
    33  	"github.com/iasthc/atlas/sql/migrate"
    34  	"github.com/iasthc/atlas/sql/schema"
    35  	"github.com/iasthc/atlas/sql/sqlcheck"
    36  	"github.com/iasthc/atlas/sql/sqlclient"
    37  	"github.com/iasthc/atlas/sql/sqltool"
    38  
    39  	"github.com/google/uuid"
    40  	"github.com/hashicorp/hcl/v2"
    41  	"github.com/hashicorp/hcl/v2/hclparse"
    42  	"github.com/hashicorp/hcl/v2/hclsyntax"
    43  	"github.com/spf13/cobra"
    44  )
    45  
    46  func init() {
    47  	migrateCmd := migrateCmd()
    48  	migrateCmd.AddCommand(
    49  		migrateApplyCmd(),
    50  		migrateDiffCmd(),
    51  		migrateHashCmd(),
    52  		migrateImportCmd(),
    53  		migrateLintCmd(),
    54  		migrateNewCmd(),
    55  		migrateSetCmd(),
    56  		migrateStatusCmd(),
    57  		migrateValidateCmd(),
    58  	)
    59  	Root.AddCommand(migrateCmd)
    60  }
    61  
    62  // migrateCmd represents the subcommand 'atlas migrate'.
    63  func migrateCmd() *cobra.Command {
    64  	cmd := &cobra.Command{
    65  		Use:   "migrate",
    66  		Short: "Manage versioned migration files",
    67  		Long:  "'atlas migrate' wraps several sub-commands for migration management.",
    68  	}
    69  	addGlobalFlags(cmd.PersistentFlags())
    70  	return cmd
    71  }
    72  
    73  type migrateApplyFlags struct {
    74  	url             string
    75  	dirURL          string
    76  	revisionSchema  string
    77  	dryRun          bool
    78  	logFormat       string
    79  	lockTimeout     time.Duration
    80  	allowDirty      bool   // allow working on a database that already has resources
    81  	fromVersion     string // compute pending files based on this version
    82  	baselineVersion string // apply with this version as baseline
    83  	txMode          string // (none, file, all)
    84  }
    85  
    86  func (f *migrateApplyFlags) migrateOptions() (opts []migrate.ExecutorOption) {
    87  	if f.allowDirty {
    88  		opts = append(opts, migrate.WithAllowDirty(true))
    89  	}
    90  	if v := f.baselineVersion; v != "" {
    91  		opts = append(opts, migrate.WithBaselineVersion(v))
    92  	}
    93  	if v := f.fromVersion; v != "" {
    94  		opts = append(opts, migrate.WithFromVersion(v))
    95  	}
    96  	return
    97  }
    98  
    99  func migrateApplyCmd() *cobra.Command {
   100  	var (
   101  		flags migrateApplyFlags
   102  		cmd   = &cobra.Command{
   103  			Use:   "apply [flags] [amount]",
   104  			Short: "Applies pending migration files on the connected database.",
   105  			Long: `'atlas migrate apply' reads the migration state of the connected database and computes what migrations are pending.
   106  It then attempts to apply the pending migration files in the correct order onto the database. 
   107  The first argument denotes the maximum number of migration files to apply.
   108  As a safety measure 'atlas migrate apply' will abort with an error, if:
   109    - the migration directory is not in sync with the 'atlas.sum' file
   110    - the migration and database history do not match each other
   111  
   112  If run with the "--dry-run" flag, atlas will not execute any SQL.`,
   113  			Example: `  atlas migrate apply -u mysql://user:pass@localhost:3306/dbname
   114    atlas migrate apply --dir file:///path/to/migration/directory --url mysql://user:pass@localhost:3306/dbname 1
   115    atlas migrate apply --env dev 1
   116    atlas migrate apply --dry-run --env dev 1`,
   117  			Args: cobra.MaximumNArgs(1),
   118  			RunE: func(cmd *cobra.Command, args []string) (err error) {
   119  				switch {
   120  				case GlobalFlags.SelectedEnv == "":
   121  					if err := migrateFlagsFromConfig(cmd); err != nil {
   122  						return err
   123  					}
   124  					return migrateApplyRun(cmd, args, flags, &MigrateReport{}) // nop reporter
   125  				default:
   126  					project, envs, err := EnvByName(GlobalFlags.SelectedEnv, WithInput(GlobalFlags.Vars))
   127  					if err != nil {
   128  						return err
   129  					}
   130  					set := NewReportProvider(project, envs)
   131  					defer set.Flush(cmd)
   132  					return cmdEnvsRun(envs, setMigrateEnvFlags, cmd, func(env *Env) error {
   133  						return migrateApplyRun(cmd, args, flags, set.ReportFor(flags, env))
   134  					})
   135  				}
   136  			},
   137  		}
   138  	)
   139  	cmd.Flags().SortFlags = false
   140  	addFlagURL(cmd.Flags(), &flags.url)
   141  	addFlagDirURL(cmd.Flags(), &flags.dirURL)
   142  	addFlagLog(cmd.Flags(), &flags.logFormat)
   143  	addFlagFormat(cmd.Flags(), &flags.logFormat)
   144  	addFlagRevisionSchema(cmd.Flags(), &flags.revisionSchema)
   145  	addFlagDryRun(cmd.Flags(), &flags.dryRun)
   146  	addFlagLockTimeout(cmd.Flags(), &flags.lockTimeout)
   147  	cmd.Flags().StringVarP(&flags.fromVersion, flagFrom, "", "", "calculate pending files from the given version (including it)")
   148  	cmd.Flags().StringVarP(&flags.baselineVersion, flagBaseline, "", "", "start the first migration after the given baseline version")
   149  	cmd.Flags().StringVarP(&flags.txMode, flagTxMode, "", txModeFile, "set transaction mode [none, file, all]")
   150  	cmd.Flags().BoolVarP(&flags.allowDirty, flagAllowDirty, "", false, "allow start working on a non-clean database")
   151  	cmd.MarkFlagsMutuallyExclusive(flagFrom, flagBaseline)
   152  	cmd.MarkFlagsMutuallyExclusive(flagLog, flagFormat)
   153  	return cmd
   154  }
   155  
   156  // migrateApplyCmd represents the 'atlas migrate apply' subcommand.
   157  func migrateApplyRun(cmd *cobra.Command, args []string, flags migrateApplyFlags, mr *MigrateReport) (err error) {
   158  	var (
   159  		count int
   160  		ctx   = cmd.Context()
   161  	)
   162  	if len(args) > 0 {
   163  		if count, err = strconv.Atoi(args[0]); err != nil {
   164  			return err
   165  		}
   166  		if count < 1 {
   167  			return fmt.Errorf("cannot apply '%d' migration files", count)
   168  		}
   169  	}
   170  	// Open and validate the migration directory.
   171  	migrationDir, err := dir(flags.dirURL, false)
   172  	if err != nil {
   173  		return err
   174  	}
   175  	if err := migrate.Validate(migrationDir); err != nil {
   176  		printChecksumError(cmd)
   177  		return err
   178  	}
   179  	// Open a client to the database.
   180  	if flags.url == "" {
   181  		return errors.New(`required flag "url" not set`)
   182  	}
   183  	client, err := sqlclient.Open(ctx, flags.url)
   184  	if err != nil {
   185  		return err
   186  	}
   187  	defer client.Close()
   188  	// Prevent usage printing after input validation.
   189  	cmd.SilenceUsage = true
   190  	// Acquire a lock.
   191  	if l, ok := client.Driver.(schema.Locker); ok {
   192  		unlock, err := l.Lock(ctx, applyLockValue, flags.lockTimeout)
   193  		if err != nil {
   194  			return fmt.Errorf("acquiring database lock: %w", err)
   195  		}
   196  		// If unlocking fails notify the user about it.
   197  		defer func() { cobra.CheckErr(unlock()) }()
   198  	}
   199  	if err := checkRevisionSchemaClarity(cmd, client, flags.revisionSchema); err != nil {
   200  		return err
   201  	}
   202  	var rrw migrate.RevisionReadWriter
   203  	if rrw, err = entRevisions(ctx, client, flags.revisionSchema); err != nil {
   204  		return err
   205  	}
   206  	if err := rrw.(*cmdmigrate.EntRevisions).Migrate(ctx); err != nil {
   207  		return err
   208  	}
   209  	// Setup reporting info.
   210  	report := cmdlog.NewMigrateApply(client, migrationDir)
   211  	mr.Init(client, report, rrw.(*cmdmigrate.EntRevisions))
   212  	// If cloud reporting is enabled, and we cannot obtain the current
   213  	// target identifier, abort and report it to the user.
   214  	if err := mr.RecordTargetID(cmd.Context()); err != nil {
   215  		return err
   216  	}
   217  	// Determine pending files.
   218  	opts := append(flags.migrateOptions(), migrate.WithOperatorVersion(operatorVersion()), migrate.WithLogger(report))
   219  	ex, err := migrate.NewExecutor(client.Driver, migrationDir, rrw, opts...)
   220  	if err != nil {
   221  		return err
   222  	}
   223  	pending, err := ex.Pending(ctx)
   224  	if err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) {
   225  		return err
   226  	}
   227  	noPending := errors.Is(err, migrate.ErrNoPendingFiles)
   228  	// Get the pending files before obtaining applied revisions,
   229  	// as the Executor may write a baseline revision in the table.
   230  	applied, err := rrw.ReadRevisions(ctx)
   231  	if err != nil {
   232  		return err
   233  	}
   234  	if noPending {
   235  		migrate.LogNoPendingFiles(report, applied)
   236  		return mr.Done(cmd, flags)
   237  	}
   238  	if l := len(pending); count == 0 || count >= l {
   239  		// Cannot apply more than len(pending) migration files.
   240  		count = l
   241  	}
   242  	pending = pending[:count]
   243  	migrate.LogIntro(report, applied, pending)
   244  	var (
   245  		mux = tx{
   246  			dryRun: flags.dryRun,
   247  			mode:   flags.txMode,
   248  			schema: flags.revisionSchema,
   249  			c:      client,
   250  			rrw:    rrw,
   251  		}
   252  		drv migrate.Driver
   253  	)
   254  	for _, f := range pending {
   255  		if drv, rrw, err = mux.driverFor(ctx, f); err != nil {
   256  			break
   257  		}
   258  		if ex, err = migrate.NewExecutor(drv, migrationDir, rrw, opts...); err != nil {
   259  			return fmt.Errorf("unexpected exectuor creation error: %w", err)
   260  		}
   261  		if err = mux.mayRollback(ex.Execute(ctx, f)); err != nil {
   262  			break
   263  		}
   264  		if err = mux.mayCommit(); err != nil {
   265  			break
   266  		}
   267  	}
   268  	if err == nil {
   269  		if err = mux.commit(); err == nil {
   270  			report.Log(migrate.LogDone{})
   271  		}
   272  	}
   273  	if err != nil {
   274  		report.Error = err.Error()
   275  	}
   276  	return errors.Join(err, mr.Done(cmd, flags))
   277  }
   278  
   279  type (
   280  	// MigrateReport responsible for reporting 'migrate apply' reports.
   281  	MigrateReport struct {
   282  		id     string // target id
   283  		env    *Env   // nil, if no env set
   284  		client *sqlclient.Client
   285  		log    *cmdlog.MigrateApply
   286  		rrw    *cmdmigrate.EntRevisions
   287  		done   func(*cloudapi.ReportMigrationInput)
   288  	}
   289  	// MigrateReportSet is a set of reports.
   290  	MigrateReportSet struct {
   291  		cloudapi.ReportMigrationSetInput
   292  		client *cloudapi.Client
   293  		done   int // number of done migrations
   294  	}
   295  )
   296  
   297  // NewReportProvider returns a new ReporterProvider.
   298  func NewReportProvider(project *Project, envs []*Env) *MigrateReportSet {
   299  	s := &MigrateReportSet{
   300  		client: project.cfg.Client,
   301  		ReportMigrationSetInput: cloudapi.ReportMigrationSetInput{
   302  			ID:        uuid.NewString(),
   303  			StartTime: time.Now(),
   304  			Planned:   len(envs),
   305  		},
   306  	}
   307  	s.Step("Start migration for %d targets", len(envs))
   308  	for _, e := range envs {
   309  		s.StepLog(s.RedactedURL(e.URL))
   310  	}
   311  	return s
   312  }
   313  
   314  // RedactedURL returns the redacted URL of the given environment at index i.
   315  func (*MigrateReportSet) RedactedURL(u string) string {
   316  	u, err := cloudapi.RedactedURL(u)
   317  	if err != nil {
   318  		return fmt.Sprintf("Error: redacting URL: %v", err)
   319  	}
   320  	return u
   321  }
   322  
   323  // Step starts a new reporting step.
   324  func (s *MigrateReportSet) Step(format string, args ...interface{}) {
   325  	if len(s.Log) > 0 && s.Log[len(s.Log)-1].EndTime.IsZero() {
   326  		s.Log[len(s.Log)-1].EndTime = time.Now()
   327  	}
   328  	s.Log = append(s.Log, cloudapi.ReportStep{
   329  		StartTime: time.Now(),
   330  		Text:      fmt.Sprintf(format, args...),
   331  	})
   332  }
   333  
   334  // StepLog logs a line to the current reporting step.
   335  func (s *MigrateReportSet) StepLog(format string, args ...interface{}) {
   336  	if len(s.Log) == 0 {
   337  		s.Step("Unnamed step") // Unexpected.
   338  	}
   339  	s.Log[len(s.Log)-1].Log = append(s.Log[len(s.Log)-1].Log, cloudapi.ReportStepLog{
   340  		Text: fmt.Sprintf(format, args...),
   341  	})
   342  }
   343  
   344  // StepLogError logs a line to the current reporting step.
   345  func (s *MigrateReportSet) StepLogError(text string) {
   346  	if !strings.HasPrefix(text, "Error") {
   347  		text = "Error: " + text
   348  	}
   349  	s.StepLog(text)
   350  	s.Error = &text
   351  	s.Log[len(s.Log)-1].Error = true
   352  }
   353  
   354  // ReportFor returns a new MigrateReport for the given environment.
   355  func (s *MigrateReportSet) ReportFor(flags migrateApplyFlags, e *Env) *MigrateReport {
   356  	s.Step("Run migration: %d", s.done+1)
   357  	s.StepLog("Target URL: %s", s.RedactedURL(e.URL))
   358  	s.StepLog("Migration directory: %s", s.RedactedURL(flags.dirURL))
   359  	return &MigrateReport{
   360  		env: e,
   361  		done: func(r *cloudapi.ReportMigrationInput) {
   362  			s.done++
   363  			s.Log[len(s.Log)-1].EndTime = time.Now()
   364  			if r.Error != nil && *r.Error != "" {
   365  				s.StepLogError(*r.Error)
   366  			}
   367  			s.Completed = append(s.Completed, *r)
   368  		},
   369  	}
   370  }
   371  
   372  // Flush report the migration deployment to the cloud.
   373  // The current implementation is simplistic and sends each
   374  // report separately without marking them as part of a group.
   375  //
   376  // Note that reporting errors are logged, but not cause Atlas to fail.
   377  func (s *MigrateReportSet) Flush(cmd *cobra.Command) {
   378  	var err error
   379  	switch {
   380  	// Skip reporting if set is empty,
   381  	// or there is no cloud connectivity.
   382  	case s.Planned == 0, s.client == nil:
   383  		return
   384  	// Single migration.
   385  	case s.Planned == 1 && len(s.Completed) == 1:
   386  		err = s.client.ReportMigration(cmd.Context(), s.Completed[0])
   387  	// Multi environment migration (e.g., multi-tenancy).
   388  	case s.Planned > 1:
   389  		s.EndTime = time.Now()
   390  		err = s.client.ReportMigrationSet(cmd.Context(), s.ReportMigrationSetInput)
   391  	}
   392  	if err != nil {
   393  		txt := fmt.Sprintf("Error: %s", strings.TrimRight(err.Error(), "\n"))
   394  		// Ensure errors are printed in new lines.
   395  		if cmd.Flags().Changed(flagFormat) {
   396  			txt = "\n" + txt
   397  		}
   398  		cmd.PrintErrln(txt)
   399  	}
   400  }
   401  
   402  // Init the report if the necessary dependencies.
   403  func (r *MigrateReport) Init(c *sqlclient.Client, l *cmdlog.MigrateApply, rrw *cmdmigrate.EntRevisions) {
   404  	r.client, r.log, r.rrw = c, l, rrw
   405  }
   406  
   407  // RecordTargetID asks the revisions-table to allow or provide
   408  // the target identifier if cloud reporting is enabled.
   409  func (r *MigrateReport) RecordTargetID(ctx context.Context) error {
   410  	if r.CloudEnabled() {
   411  		id, err := r.rrw.ID(ctx, operatorVersion())
   412  		if err != nil {
   413  			return err
   414  		}
   415  		r.id = id
   416  	}
   417  	return nil
   418  }
   419  
   420  // Done closes and flushes this report.
   421  func (r *MigrateReport) Done(cmd *cobra.Command, flags migrateApplyFlags) error {
   422  	if !r.CloudEnabled() {
   423  		return logApply(cmd, cmd.OutOrStdout(), flags, r.log)
   424  	}
   425  	var (
   426  		ver  string
   427  		clog bytes.Buffer
   428  		err  = logApply(cmd, io.MultiWriter(cmd.OutOrStdout(), &clog), flags, r.log)
   429  	)
   430  	switch rev, err1 := r.rrw.CurrentRevision(cmd.Context()); {
   431  	case ent.IsNotFound(err1):
   432  	case err1 != nil:
   433  		return errors.Join(err, err1)
   434  	default:
   435  		ver = rev.Version
   436  	}
   437  	r.done(&cloudapi.ReportMigrationInput{
   438  		ProjectName:  r.env.cfg.Project,
   439  		EnvName:      r.env.Name,
   440  		DirName:      path.Base(flags.dirURL),
   441  		AtlasVersion: operatorVersion(),
   442  		Target: cloudapi.DeployedTargetInput{
   443  			ID:     r.id,
   444  			Schema: r.client.URL.Schema,
   445  			URL:    r.client.URL.Redacted(),
   446  		},
   447  		StartTime:      r.log.Start,
   448  		EndTime:        r.log.End,
   449  		FromVersion:    r.log.Current,
   450  		ToVersion:      r.log.Target,
   451  		CurrentVersion: ver,
   452  		Error: func() *string {
   453  			if r.log.Error != "" {
   454  				return &r.log.Error
   455  			}
   456  			return nil
   457  		}(),
   458  		Files: func() []cloudapi.DeployedFileInput {
   459  			files := make([]cloudapi.DeployedFileInput, len(r.log.Applied))
   460  			for i, f := range r.log.Applied {
   461  				files[i] = cloudapi.DeployedFileInput{
   462  					Name:      f.Name(),
   463  					Content:   string(f.Bytes()),
   464  					StartTime: f.Start,
   465  					EndTime:   f.End,
   466  					Skipped:   f.Skipped,
   467  					Applied:   len(f.Applied),
   468  					Error:     (*cloudapi.StmtErrorInput)(f.Error),
   469  				}
   470  			}
   471  			return files
   472  		}(),
   473  		Log: clog.String(),
   474  	})
   475  	return err
   476  }
   477  
   478  // CloudEnabled reports if cloud reporting is enabled.
   479  func (r *MigrateReport) CloudEnabled() bool {
   480  	return r.env != nil && r.env.cfg != nil && r.env.cfg.Client != nil && r.env.cfg.Project != ""
   481  }
   482  
   483  func logApply(cmd *cobra.Command, w io.Writer, flags migrateApplyFlags, r *cmdlog.MigrateApply) error {
   484  	var (
   485  		err error
   486  		f   = cmdlog.MigrateApplyTemplate
   487  	)
   488  	if v := flags.logFormat; v != "" {
   489  		f, err = template.New("format").Funcs(cmdlog.ApplyTemplateFuncs).Parse(v)
   490  		if err != nil {
   491  			return fmt.Errorf("parse format: %w", err)
   492  		}
   493  	}
   494  	if err = f.Execute(w, r); err != nil {
   495  		return fmt.Errorf("execute log template: %w", err)
   496  	}
   497  	// In case a custom logging was configured, avoid reporting errors twice.
   498  	// For example, printing error lines may break parsing the JSON output.
   499  	cmd.SilenceErrors = flags.logFormat != ""
   500  	return nil
   501  }
   502  
   503  type migrateDiffFlags struct {
   504  	edit              bool
   505  	desiredURLs       []string
   506  	dirURL, dirFormat string
   507  	devURL            string
   508  	schemas           []string
   509  	lockTimeout       time.Duration
   510  	format            string
   511  	revisionSchema    string // revision schema name
   512  	qualifier         string // optional table qualifier
   513  }
   514  
   515  // migrateDiffCmd represents the 'atlas migrate diff' subcommand.
   516  func migrateDiffCmd() *cobra.Command {
   517  	var (
   518  		flags migrateDiffFlags
   519  		cmd   = &cobra.Command{
   520  			Use:   "diff [flags] [name]",
   521  			Short: "Compute the diff between the migration directory and a desired state and create a new migration file.",
   522  			Long: `'atlas migrate diff' uses the dev-database to re-run all migration files in the migration directory, compares
   523  it to a given desired state and create a new migration file containing SQL statements to migrate the migration
   524  directory state to the desired schema. The desired state can be another connected database or an HCL file.`,
   525  			Example: `  atlas migrate diff --dev-url mysql://user:pass@localhost:3306/dev --to file://schema.hcl
   526    atlas migrate diff --dev-url mysql://user:pass@localhost:3306/dev --to file://atlas.hcl add_users_table
   527    atlas migrate diff --dev-url mysql://user:pass@localhost:3306/dev --to mysql://user:pass@localhost:3306/dbname
   528    atlas migrate diff --env dev --format '{{ sql . "  " }}'`,
   529  			Args: cobra.MaximumNArgs(1),
   530  			PreRunE: func(cmd *cobra.Command, args []string) error {
   531  				if err := migrateFlagsFromConfig(cmd); err != nil {
   532  					return err
   533  				}
   534  				if err := dirFormatBC(flags.dirFormat, &flags.dirURL); err != nil {
   535  					return err
   536  				}
   537  				return checkDir(cmd, flags.dirURL, true)
   538  			},
   539  			RunE: func(cmd *cobra.Command, args []string) error {
   540  				env, err := selectEnv(cmd)
   541  				if err != nil {
   542  					return err
   543  				}
   544  				return migrateDiffRun(cmd, args, flags, env)
   545  			},
   546  		}
   547  	)
   548  	cmd.Flags().SortFlags = false
   549  	addFlagToURLs(cmd.Flags(), &flags.desiredURLs)
   550  	addFlagDevURL(cmd.Flags(), &flags.devURL)
   551  	addFlagDirURL(cmd.Flags(), &flags.dirURL)
   552  	addFlagDirFormat(cmd.Flags(), &flags.dirFormat)
   553  	addFlagRevisionSchema(cmd.Flags(), &flags.revisionSchema)
   554  	addFlagSchemas(cmd.Flags(), &flags.schemas)
   555  	addFlagLockTimeout(cmd.Flags(), &flags.lockTimeout)
   556  	addFlagFormat(cmd.Flags(), &flags.format)
   557  	cmd.Flags().StringVar(&flags.qualifier, flagQualifier, "", "qualify tables with custom qualifier when working on a single schema")
   558  	cmd.Flags().BoolVarP(&flags.edit, flagEdit, "", false, "edit the generated migration file(s)")
   559  	cobra.CheckErr(cmd.MarkFlagRequired(flagTo))
   560  	cobra.CheckErr(cmd.MarkFlagRequired(flagDevURL))
   561  	return cmd
   562  }
   563  
// migrateDiffRun implements 'atlas migrate diff'.
//
// It performs the following steps:
//   - connect to the dev database, locking it when its driver implements
//     schema.Locker;
//   - open the migration directory given by flags.dirURL, wrapped in editDir
//     when --edit is set;
//   - write a new migration file that moves the directory state to the
//     desired state given by flags.desiredURLs.
//
// The optional first argument names the new file. If cmdext.States provides
// a custom differ for the desired URLs, the work is delegated to it. A
// "no plan" result is turned into a user message by maskNoPlan instead of an
// error.
func migrateDiffRun(cmd *cobra.Command, args []string, flags migrateDiffFlags, env *Env) error {
	ctx := cmd.Context()
	dev, err := sqlclient.Open(ctx, flags.devURL)
	if err != nil {
		return err
	}
	defer dev.Close()
	// Acquire a lock.
	if l, ok := dev.Driver.(schema.Locker); ok {
		unlock, err := l.Lock(ctx, "atlas_migrate_diff", flags.lockTimeout)
		if err != nil {
			return fmt.Errorf("acquiring database lock: %w", err)
		}
		// If unlocking fails notify the user about it.
		defer func() { cobra.CheckErr(unlock()) }()
	}
	// Open the migration directory.
	u, err := url.Parse(flags.dirURL)
	if err != nil {
		return err
	}
	dir, err := dirURL(u, false)
	if err != nil {
		return err
	}
	if flags.edit {
		dir = &editDir{dir}
	}
	var name, indent string
	if len(args) > 0 {
		name = args[0]
	}
	// The formatter is derived from the directory URL. A custom --format may
	// either request statement indentation or replace the file template
	// (see mayIndent).
	f, err := formatter(u)
	if err != nil {
		return err
	}
	if f, indent, err = mayIndent(u, f, flags.format); err != nil {
		return err
	}
	// If there is a state-loader that requires a custom
	// 'migrate diff' handling, offload it the work.
	if d, ok := cmdext.States.Differ(flags.desiredURLs); ok {
		err := d.MigrateDiff(ctx, &cmdext.MigrateDiffOptions{
			To:      flags.desiredURLs,
			Name:    name,
			Indent:  indent,
			Dir:     dir,
			Dev:     dev,
			Options: env.DiffOptions(),
		})
		return maskNoPlan(cmd, err)
	}
	// Get a state reader for the desired state.
	desired, err := stateReader(ctx, &stateReaderConfig{
		urls:    flags.desiredURLs,
		dev:     dev,
		client:  dev,
		schemas: flags.schemas,
		vars:    GlobalFlags.Vars,
	})
	if err != nil {
		return err
	}
	defer desired.Close()
	opts := []migrate.PlannerOption{
		migrate.PlanFormat(f),
		migrate.PlanWithIndent(indent),
		migrate.PlanWithDiffOptions(env.DiffOptions()...),
	}
	if dev.URL.Schema != "" {
		// Disable tables qualifier in schema-mode.
		// NOTE(review): the value passed is --qualifier, which defaults to "".
		// An empty qualifier presumably leaves tables unqualified; a non-empty
		// one qualifies them. Confirm against migrate.PlanWithSchemaQualifier.
		opts = append(opts, migrate.PlanWithSchemaQualifier(flags.qualifier))
	}
	// Plan the changes and create a new migration file.
	pl := migrate.NewPlanner(dev.Driver, dir, opts...)
	plan, err := func() (*migrate.Plan, error) {
		// A schema-scoped dev URL uses PlanSchema; otherwise Plan is used
		// (presumably covering the whole database - confirm).
		if dev.URL.Schema != "" {
			return pl.PlanSchema(ctx, name, desired.StateReader)
		}
		return pl.Plan(ctx, name, desired.StateReader)
	}()
	var cerr *migrate.NotCleanError
	switch {
	// An unclean, non-schema-scoped dev database combined with a
	// schema-scoped desired state gets a hint to narrow the dev URL.
	case errors.As(err, &cerr) && dev.URL.Schema == "" && desired.schema != "":
		return fmt.Errorf("dev database is not clean (%s). Add a schema to the URL to limit the scope of the connection", cerr.Reason)
	case err != nil:
		return maskNoPlan(cmd, err)
	default:
		// Write the plan to a new file.
		return pl.WritePlan(plan)
	}
}
   656  
   657  func mayIndent(dir *url.URL, f migrate.Formatter, format string) (migrate.Formatter, string, error) {
   658  	if format == "" {
   659  		return f, "", nil
   660  	}
   661  	reject := errors.New(`'sql' can only be used to indent statements`)
   662  	t, err := template.New("format").
   663  		// The "sql" is a dummy function to detect if the
   664  		// template was used to indent the SQL statements.
   665  		Funcs(template.FuncMap{"sql": func(...any) (string, error) { return "", reject }}).
   666  		Parse(format)
   667  	if err != nil {
   668  		return nil, "", fmt.Errorf("parse format: %w", err)
   669  	}
   670  	indent, ok := func() (string, bool) {
   671  		if len(t.Tree.Root.Nodes) != 1 {
   672  			return "", false
   673  		}
   674  		n, ok := t.Tree.Root.Nodes[0].(*parse.ActionNode)
   675  		if !ok || len(n.Pipe.Cmds) != 1 || len(n.Pipe.Cmds[0].Args) < 2 || len(n.Pipe.Cmds[0].Args) > 3 {
   676  			return "", false
   677  		}
   678  		args := n.Pipe.Cmds[0].Args
   679  		if args[0].String() != "sql" || args[1].String() != "." && args[1].String() != "$" {
   680  			return "", false
   681  		}
   682  		d := `""` // empty string as arg.
   683  		if len(args) == 3 {
   684  			d = args[2].String()
   685  		}
   686  		return d, true
   687  	}()
   688  	if ok {
   689  		if indent, err = strconv.Unquote(indent); err != nil {
   690  			return nil, "", fmt.Errorf("parse indent: %w", err)
   691  		}
   692  		return f, indent, nil
   693  	}
   694  	// If the template is not an indent, it cannot contain the "sql" function.
   695  	if err := t.Execute(io.Discard, &migrate.Plan{}); err != nil && errors.Is(err, reject) {
   696  		return nil, "", fmt.Errorf("%v. got: %v", reject, t.Root.String())
   697  	}
   698  	tfs := f.(migrate.TemplateFormatter)
   699  	if len(tfs) != 1 {
   700  		return nil, "", fmt.Errorf("cannot use format with: %q", dir.Query().Get("format"))
   701  	}
   702  	return migrate.TemplateFormatter{{N: tfs[0].N, C: t}}, "", nil
   703  }
   704  
   705  // maskNoPlan masks ErrNoPlan errors.
   706  func maskNoPlan(cmd *cobra.Command, err error) error {
   707  	if errors.Is(err, migrate.ErrNoPlan) {
   708  		cmd.Println("The migration directory is synced with the desired state, no changes to be made")
   709  		return nil
   710  	}
   711  	return err
   712  }
   713  
   714  type migrateHashFlags struct{ dirURL, dirFormat string }
   715  
   716  // migrateHashCmd represents the 'atlas migrate hash' subcommand.
   717  func migrateHashCmd() *cobra.Command {
   718  	var (
   719  		flags migrateHashFlags
   720  		cmd   = &cobra.Command{
   721  			Use:   "hash [flags]",
   722  			Short: "Hash (re-)creates an integrity hash file for the migration directory.",
   723  			Long: `'atlas migrate hash' computes the integrity hash sum of the migration directory and stores it in the atlas.sum file.
   724  This command should be used whenever a manual change in the migration directory was made.`,
   725  			Example: `  atlas migrate hash`,
   726  			PreRunE: func(cmd *cobra.Command, args []string) error {
   727  				if err := migrateFlagsFromConfig(cmd); err != nil {
   728  					return err
   729  				}
   730  				return dirFormatBC(flags.dirFormat, &flags.dirURL)
   731  			},
   732  			RunE: func(cmd *cobra.Command, args []string) error {
   733  				dir, err := dir(flags.dirURL, false)
   734  				if err != nil {
   735  					return err
   736  				}
   737  				sum, err := dir.Checksum()
   738  				if err != nil {
   739  					return err
   740  				}
   741  				return migrate.WriteSumFile(dir, sum)
   742  			},
   743  		}
   744  	)
   745  	addFlagDirURL(cmd.Flags(), &flags.dirURL)
   746  	addFlagDirFormat(cmd.Flags(), &flags.dirFormat)
   747  	cmd.Flags().Bool("force", false, "")
   748  	cobra.CheckErr(cmd.Flags().MarkDeprecated("force", "you can safely omit it."))
   749  	return cmd
   750  }
   751  
   752  type migrateImportFlags struct{ fromURL, toURL, dirFormat string }
   753  
   754  // migrateImportCmd represents the 'atlas migrate import' subcommand.
   755  func migrateImportCmd() *cobra.Command {
   756  	var (
   757  		flags migrateImportFlags
   758  		cmd   = &cobra.Command{
   759  			Use:     "import [flags]",
   760  			Short:   "Import a migration directory from another migration management tool to the Atlas format.",
   761  			Example: `  atlas migrate import --from file:///path/to/source/directory?format=liquibase --to file:///path/to/migration/directory`,
   762  			// Validate the source directory. Consider a directory with no sum file
   763  			// valid, since it might be an import from an existing project.
   764  			PreRunE: func(cmd *cobra.Command, _ []string) error {
   765  				if err := migrateFlagsFromConfig(cmd); err != nil {
   766  					return err
   767  				}
   768  				if err := dirFormatBC(flags.dirFormat, &flags.fromURL); err != nil {
   769  					return err
   770  				}
   771  				d, err := dir(flags.fromURL, false)
   772  				if err != nil {
   773  					return err
   774  				}
   775  				if err = migrate.Validate(d); err != nil && !errors.Is(err, migrate.ErrChecksumNotFound) {
   776  					printChecksumError(cmd)
   777  					return err
   778  				}
   779  				return nil
   780  			},
   781  			RunE: func(cmd *cobra.Command, args []string) error {
   782  				return migrateImportRun(cmd, args, flags)
   783  			},
   784  		}
   785  	)
   786  	cmd.Flags().SortFlags = false
   787  	addFlagDirURL(cmd.Flags(), &flags.fromURL, flagFrom)
   788  	addFlagDirURL(cmd.Flags(), &flags.toURL, flagTo)
   789  	addFlagDirFormat(cmd.Flags(), &flags.dirFormat)
   790  	return cmd
   791  }
   792  
   793  func migrateImportRun(cmd *cobra.Command, _ []string, flags migrateImportFlags) error {
   794  	p, err := url.Parse(flags.fromURL)
   795  	if err != nil {
   796  		return err
   797  	}
   798  	if f := p.Query().Get("format"); f == "" || f == formatAtlas {
   799  		return fmt.Errorf("cannot import a migration directory already in %q format", formatAtlas)
   800  	}
   801  	src, err := dir(flags.fromURL, false)
   802  	if err != nil {
   803  		return err
   804  	}
   805  	trgt, err := dir(flags.toURL, true)
   806  	if err != nil {
   807  		return err
   808  	}
   809  	// Target must be empty.
   810  	ff, err := trgt.Files()
   811  	switch {
   812  	case err != nil:
   813  		return err
   814  	case len(ff) != 0:
   815  		return errors.New("target migration directory must be empty")
   816  	}
   817  	ff, err = src.Files()
   818  	switch {
   819  	case err != nil:
   820  		return err
   821  	case len(ff) == 0:
   822  		fmt.Fprint(cmd.OutOrStderr(), "nothing to import")
   823  		cmd.SilenceUsage = true
   824  		return nil
   825  	}
   826  	// Fix version numbers for Flyway repeatable migrations.
   827  	if _, ok := src.(*sqltool.FlywayDir); ok {
   828  		sqltool.SetRepeatableVersion(ff)
   829  	}
   830  	// Extract the statements for each of the migration files,
   831  	// add them to a plan to format with the DefaultFormatter.
   832  	for _, f := range ff {
   833  		stmts, err := f.StmtDecls()
   834  		if err != nil {
   835  			return err
   836  		}
   837  		plan := &migrate.Plan{
   838  			Version: f.Version(),
   839  			Name:    f.Desc(),
   840  			Changes: make([]*migrate.Change, len(stmts)),
   841  		}
   842  		var buf strings.Builder
   843  		for i, s := range stmts {
   844  			for _, c := range s.Comments {
   845  				buf.WriteString(c)
   846  				if !strings.HasSuffix(c, "\n") {
   847  					buf.WriteString("\n")
   848  				}
   849  			}
   850  			buf.WriteString(strings.TrimSuffix(s.Text, ";"))
   851  			plan.Changes[i] = &migrate.Change{Cmd: buf.String()}
   852  			buf.Reset()
   853  		}
   854  		files, err := migrate.DefaultFormatter.Format(plan)
   855  		if err != nil {
   856  			return err
   857  		}
   858  		for _, f := range files {
   859  			if err := trgt.WriteFile(f.Name(), f.Bytes()); err != nil {
   860  				return err
   861  			}
   862  		}
   863  	}
   864  	sum, err := trgt.Checksum()
   865  	if err != nil {
   866  		return err
   867  	}
   868  	return migrate.WriteSumFile(trgt, sum)
   869  }
   870  
   871  type migrateLintFlags struct {
   872  	dirURL, dirFormat string
   873  	devURL            string
   874  	logFormat         string
   875  	latest            uint
   876  	gitBase, gitDir   string
   877  }
   878  
   879  // migrateLintCmd represents the 'atlas migrate lint' subcommand.
   880  func migrateLintCmd() *cobra.Command {
   881  	var (
   882  		flags migrateLintFlags
   883  		cmd   = &cobra.Command{
   884  			Use:   "lint [flags]",
   885  			Short: "Run analysis on the migration directory",
   886  			Example: `  atlas migrate lint --env dev
   887    atlas migrate lint --dir file:///path/to/migration/directory --dev-url mysql://root:pass@localhost:3306 --latest 1
   888    atlas migrate lint --dir file:///path/to/migration/directory --dev-url mysql://root:pass@localhost:3306 --git-base master
   889    atlas migrate lint --dir file:///path/to/migration/directory --dev-url mysql://root:pass@localhost:3306 --format '{{ json .Files }}'`,
   890  			PreRunE: func(cmd *cobra.Command, args []string) error {
   891  				if err := migrateFlagsFromConfig(cmd); err != nil {
   892  					return err
   893  				}
   894  				return dirFormatBC(flags.dirFormat, &flags.dirURL)
   895  			},
   896  			RunE: func(cmd *cobra.Command, args []string) error {
   897  				return migrateLintRun(cmd, args, flags)
   898  			},
   899  		}
   900  	)
   901  	cmd.Flags().SortFlags = false
   902  	addFlagDevURL(cmd.Flags(), &flags.devURL)
   903  	addFlagDirURL(cmd.Flags(), &flags.dirURL)
   904  	addFlagDirFormat(cmd.Flags(), &flags.dirFormat)
   905  	addFlagLog(cmd.Flags(), &flags.logFormat)
   906  	addFlagFormat(cmd.Flags(), &flags.logFormat)
   907  	cmd.Flags().UintVarP(&flags.latest, flagLatest, "", 0, "run analysis on the latest N migration files")
   908  	cmd.Flags().StringVarP(&flags.gitBase, flagGitBase, "", "", "run analysis against the base Git branch")
   909  	cmd.Flags().StringVarP(&flags.gitDir, flagGitDir, "", ".", "path to the repository working directory")
   910  	cobra.CheckErr(cmd.MarkFlagRequired(flagDevURL))
   911  	cmd.MarkFlagsMutuallyExclusive(flagLog, flagFormat)
   912  	return cmd
   913  }
   914  
   915  func migrateLintRun(cmd *cobra.Command, _ []string, flags migrateLintFlags) error {
   916  	dev, err := sqlclient.Open(cmd.Context(), flags.devURL)
   917  	if err != nil {
   918  		return err
   919  	}
   920  	defer dev.Close()
   921  	dir, err := dir(flags.dirURL, false)
   922  	if err != nil {
   923  		return err
   924  	}
   925  	var detect lint.ChangeDetector
   926  	switch {
   927  	case flags.latest == 0 && flags.gitBase == "":
   928  		return fmt.Errorf("--%s or --%s is required", flagLatest, flagGitBase)
   929  	case flags.latest > 0 && flags.gitBase != "":
   930  		return fmt.Errorf("--%s and --%s are mutually exclusive", flagLatest, flagGitBase)
   931  	case flags.latest > 0:
   932  		detect = lint.LatestChanges(dir, int(flags.latest))
   933  	case flags.gitBase != "":
   934  		detect, err = lint.NewGitChangeDetector(
   935  			dir,
   936  			lint.WithWorkDir(flags.gitDir),
   937  			lint.WithBase(flags.gitBase),
   938  			lint.WithMigrationsPath(dir.(interface{ Path() string }).Path()),
   939  		)
   940  		if err != nil {
   941  			return err
   942  		}
   943  	}
   944  	format := lint.DefaultTemplate
   945  	if f := flags.logFormat; f != "" {
   946  		format, err = template.New("format").Funcs(lint.TemplateFuncs).Parse(f)
   947  		if err != nil {
   948  			return fmt.Errorf("parse format: %w", err)
   949  		}
   950  	}
   951  	env, err := selectEnv(cmd)
   952  	if err != nil {
   953  		return err
   954  	}
   955  	az, err := sqlcheck.AnalyzerFor(dev.Name, env.Lint.Remain())
   956  	if err != nil {
   957  		return err
   958  	}
   959  	r := &lint.Runner{
   960  		Dev:            dev,
   961  		Dir:            dir,
   962  		ChangeDetector: detect,
   963  		ReportWriter: &lint.TemplateWriter{
   964  			T: format,
   965  			W: cmd.OutOrStdout(),
   966  		},
   967  		Analyzers: az,
   968  	}
   969  	err = r.Run(cmd.Context())
   970  	// Print the error in case it was not printed before.
   971  	cmd.SilenceErrors = errors.As(err, &lint.SilentError{})
   972  	cmd.SilenceUsage = cmd.SilenceErrors
   973  	return err
   974  }
   975  
   976  type migrateNewFlags struct {
   977  	edit      bool
   978  	dirURL    string
   979  	dirFormat string
   980  }
   981  
   982  // migrateNewCmd represents the 'atlas migrate new' subcommand.
   983  func migrateNewCmd() *cobra.Command {
   984  	var (
   985  		flags migrateNewFlags
   986  		cmd   = &cobra.Command{
   987  			Use:     "new [flags] [name]",
   988  			Short:   "Creates a new empty migration file in the migration directory.",
   989  			Long:    `'atlas migrate new' creates a new migration according to the configured formatter without any statements in it.`,
   990  			Example: `  atlas migrate new my-new-migration`,
   991  			Args:    cobra.MaximumNArgs(1),
   992  			PreRunE: func(cmd *cobra.Command, _ []string) error {
   993  				if err := migrateFlagsFromConfig(cmd); err != nil {
   994  					return err
   995  				}
   996  				if err := dirFormatBC(flags.dirFormat, &flags.dirURL); err != nil {
   997  					return err
   998  				}
   999  				return checkDir(cmd, flags.dirURL, true)
  1000  			},
  1001  			RunE: func(cmd *cobra.Command, args []string) error {
  1002  				return migrateNewRun(cmd, args, flags)
  1003  			},
  1004  		}
  1005  	)
  1006  	cmd.Flags().SortFlags = false
  1007  	addFlagDirURL(cmd.Flags(), &flags.dirURL)
  1008  	addFlagDirFormat(cmd.Flags(), &flags.dirFormat)
  1009  	cmd.Flags().BoolVarP(&flags.edit, flagEdit, "", false, "edit the created migration file(s)")
  1010  	return cmd
  1011  }
  1012  
  1013  func migrateNewRun(_ *cobra.Command, args []string, flags migrateNewFlags) error {
  1014  	u, err := url.Parse(flags.dirURL)
  1015  	if err != nil {
  1016  		return err
  1017  	}
  1018  	dir, err := dirURL(u, true)
  1019  	if err != nil {
  1020  		return err
  1021  	}
  1022  	if flags.edit {
  1023  		dir = &editDir{dir}
  1024  	}
  1025  	f, err := formatter(u)
  1026  	if err != nil {
  1027  		return err
  1028  	}
  1029  	var name string
  1030  	if len(args) > 0 {
  1031  		name = args[0]
  1032  	}
  1033  	return migrate.NewPlanner(nil, dir, migrate.PlanFormat(f)).WritePlan(&migrate.Plan{Name: name})
  1034  }
  1035  
  1036  type migrateSetFlags struct {
  1037  	url               string
  1038  	dirURL, dirFormat string
  1039  	revisionSchema    string
  1040  }
  1041  
  1042  // migrateSetCmd represents the 'atlas migrate set' subcommand.
  1043  func migrateSetCmd() *cobra.Command {
  1044  	var (
  1045  		flags migrateSetFlags
  1046  		cmd   = &cobra.Command{
  1047  			Use:   "set [flags] [version]",
  1048  			Short: "Set the current version of the migration history table.",
  1049  			Long: `'atlas migrate set' edits the revision table to consider all migrations up to and including the given version
  1050  to be applied. This command is usually used after manually making changes to the managed database.`,
  1051  			Example: `  atlas migrate set 3 --url mysql://user:pass@localhost:3306/
  1052    atlas migrate set --env local
  1053    atlas migrate set 1.2.4 --url mysql://user:pass@localhost:3306/my_db --revision-schema my_revisions`,
  1054  			PreRunE: func(cmd *cobra.Command, _ []string) error {
  1055  				if err := migrateFlagsFromConfig(cmd); err != nil {
  1056  					return err
  1057  				}
  1058  				if err := dirFormatBC(flags.dirFormat, &flags.dirURL); err != nil {
  1059  					return err
  1060  				}
  1061  				return checkDir(cmd, flags.dirURL, false)
  1062  			},
  1063  			RunE: func(cmd *cobra.Command, args []string) error {
  1064  				return migrateSetRun(cmd, args, flags)
  1065  			},
  1066  		}
  1067  	)
  1068  	cmd.Flags().SortFlags = false
  1069  	addFlagURL(cmd.Flags(), &flags.url)
  1070  	addFlagDirURL(cmd.Flags(), &flags.dirURL)
  1071  	addFlagDirFormat(cmd.Flags(), &flags.dirFormat)
  1072  	addFlagRevisionSchema(cmd.Flags(), &flags.revisionSchema)
  1073  	return cmd
  1074  }
  1075  
  1076  func migrateSetRun(cmd *cobra.Command, args []string, flags migrateSetFlags) (rerr error) {
  1077  	ctx := cmd.Context()
  1078  	dir, err := dir(flags.dirURL, false)
  1079  	if err != nil {
  1080  		return err
  1081  	}
  1082  	files, err := dir.Files()
  1083  	if err != nil {
  1084  		return err
  1085  	}
  1086  	client, err := sqlclient.Open(ctx, flags.url)
  1087  	if err != nil {
  1088  		return err
  1089  	}
  1090  	defer client.Close()
  1091  	// Acquire a lock.
  1092  	if l, ok := client.Driver.(schema.Locker); ok {
  1093  		unlock, err := l.Lock(ctx, applyLockValue, 0)
  1094  		if err != nil {
  1095  			return fmt.Errorf("acquiring database lock: %w", err)
  1096  		}
  1097  		// If unlocking fails notify the user about it.
  1098  		defer func() { cobra.CheckErr(unlock()) }()
  1099  	}
  1100  	if err := checkRevisionSchemaClarity(cmd, client, flags.revisionSchema); err != nil {
  1101  		return err
  1102  	}
  1103  	// Ensure revision table exists.
  1104  	rrw, err := entRevisions(ctx, client, flags.revisionSchema)
  1105  	if err != nil {
  1106  		return err
  1107  	}
  1108  	if err := rrw.Migrate(ctx); err != nil {
  1109  		return err
  1110  	}
  1111  	// Wrap manipulation in a transaction.
  1112  	tx, err := client.Tx(ctx, nil)
  1113  	if err != nil {
  1114  		return err
  1115  	}
  1116  	defer func() {
  1117  		if rerr == nil {
  1118  			rerr = tx.Commit()
  1119  		} else if err2 := tx.Rollback(); err2 != nil {
  1120  			rerr = fmt.Errorf("%v: %w", err2, err)
  1121  		}
  1122  	}()
  1123  	rrw, err = entRevisions(ctx, tx.Client, flags.revisionSchema)
  1124  	if err != nil {
  1125  		return err
  1126  	}
  1127  	revs, err := rrw.ReadRevisions(ctx)
  1128  	if err != nil {
  1129  		return err
  1130  	}
  1131  	var version string
  1132  	switch n := len(args); {
  1133  	// Prevent the case where 'migrate set' is called without a version on
  1134  	// a clean database. i.e., we allow only removing or syncing revisions.
  1135  	case n == 0 && len(revs) > 0:
  1136  		// Calling set without a version and an empty
  1137  		// migration directory purges the revision table.
  1138  		if len(files) > 0 {
  1139  			version = files[len(files)-1].Version()
  1140  		}
  1141  	case n == 1:
  1142  		// Check if the target version does exist in the migration directory.
  1143  		if idx := migrate.FilesLastIndex(files, func(f migrate.File) bool {
  1144  			return f.Version() == args[0]
  1145  		}); idx == -1 {
  1146  			return fmt.Errorf("migration with version %q not found", args[0])
  1147  		}
  1148  		version = args[0]
  1149  	default:
  1150  		return fmt.Errorf("accepts 1 arg(s), received %d", n)
  1151  	}
  1152  	log := &cmdlog.MigrateSet{}
  1153  	for _, r := range revs {
  1154  		// Check all existing revisions and ensure they precede the given version. If we encounter a partially
  1155  		// applied revision, or one with errors, mark them "fixed".
  1156  		switch {
  1157  		// remove revision to keep linear history
  1158  		case r.Version > version:
  1159  			log.Removed(r)
  1160  			if err := rrw.DeleteRevision(ctx, r.Version); err != nil {
  1161  				return err
  1162  			}
  1163  		// keep, but if with error mark "fixed"
  1164  		case r.Version == version && (r.Error != "" || r.Total != r.Applied):
  1165  			log.Set(r)
  1166  			r.Type = migrate.RevisionTypeExecute | migrate.RevisionTypeResolved
  1167  			if err := rrw.WriteRevision(ctx, r); err != nil {
  1168  				return err
  1169  			}
  1170  		}
  1171  	}
  1172  	revs, err = rrw.ReadRevisions(ctx)
  1173  	if err != nil {
  1174  		return err
  1175  	}
  1176  	// If the target version succeeds the last revision, mark
  1177  	// migrations applied, until we reach the target version.
  1178  	var pending []migrate.File
  1179  	switch {
  1180  	case len(revs) == 0:
  1181  		// Take every file until we reach target version.
  1182  		for _, f := range files {
  1183  			if f.Version() > version {
  1184  				break
  1185  			}
  1186  			pending = append(pending, f)
  1187  		}
  1188  	case version > revs[len(revs)-1].Version:
  1189  	loop:
  1190  		// Take every file succeeding the last revision until we reach target version.
  1191  		for _, f := range files {
  1192  			switch {
  1193  			case f.Version() <= revs[len(revs)-1].Version:
  1194  				// Migration precedes last revision.
  1195  			case f.Version() > version:
  1196  				// Migration succeeds target revision.
  1197  				break loop
  1198  			default: // between last revision and target
  1199  				pending = append(pending, f)
  1200  			}
  1201  		}
  1202  	}
  1203  	// Mark every pending file as applied.
  1204  	sum, err := dir.Checksum()
  1205  	if err != nil {
  1206  		return err
  1207  	}
  1208  	for _, f := range pending {
  1209  		h, err := sum.SumByName(f.Name())
  1210  		if err != nil {
  1211  			return err
  1212  		}
  1213  		rev := &migrate.Revision{
  1214  			Version:         f.Version(),
  1215  			Description:     f.Desc(),
  1216  			Type:            migrate.RevisionTypeResolved,
  1217  			ExecutedAt:      time.Now(),
  1218  			Hash:            h,
  1219  			OperatorVersion: operatorVersion(),
  1220  		}
  1221  		log.Set(rev)
  1222  		if err := rrw.WriteRevision(ctx, rev); err != nil {
  1223  			return err
  1224  		}
  1225  	}
  1226  	if log.Current, err = rrw.CurrentRevision(ctx); err != nil && !ent.IsNotFound(err) {
  1227  		return err
  1228  	}
  1229  	return cmdlog.MigrateSetTemplate.Execute(cmd.OutOrStdout(), log)
  1230  }
  1231  
  1232  type migrateStatusFlags struct {
  1233  	url               string
  1234  	dirURL, dirFormat string
  1235  	revisionSchema    string
  1236  	logFormat         string
  1237  }
  1238  
  1239  // migrateStatusCmd represents the 'atlas migrate status' subcommand.
  1240  func migrateStatusCmd() *cobra.Command {
  1241  	var (
  1242  		flags migrateStatusFlags
  1243  		cmd   = &cobra.Command{
  1244  			Use:   "status [flags]",
  1245  			Short: "Get information about the current migration status.",
  1246  			Long:  `'atlas migrate status' reports information about the current status of a connected database compared to the migration directory.`,
  1247  			Example: `  atlas migrate status --url mysql://user:pass@localhost:3306/
  1248    atlas migrate status --url mysql://user:pass@localhost:3306/ --dir file:///path/to/migration/directory`,
  1249  			PreRunE: func(cmd *cobra.Command, _ []string) error {
  1250  				if err := migrateFlagsFromConfig(cmd); err != nil {
  1251  					return err
  1252  				}
  1253  				if err := dirFormatBC(flags.dirFormat, &flags.dirURL); err != nil {
  1254  					return err
  1255  				}
  1256  				return checkDir(cmd, flags.dirURL, false)
  1257  			},
  1258  			RunE: func(cmd *cobra.Command, args []string) error {
  1259  				return migrateStatusRun(cmd, args, flags)
  1260  			},
  1261  		}
  1262  	)
  1263  	cmd.Flags().SortFlags = false
  1264  	addFlagURL(cmd.Flags(), &flags.url)
  1265  	addFlagLog(cmd.Flags(), &flags.logFormat)
  1266  	addFlagFormat(cmd.Flags(), &flags.logFormat)
  1267  	addFlagDirURL(cmd.Flags(), &flags.dirURL)
  1268  	addFlagDirFormat(cmd.Flags(), &flags.dirFormat)
  1269  	addFlagRevisionSchema(cmd.Flags(), &flags.revisionSchema)
  1270  	cmd.MarkFlagsMutuallyExclusive(flagLog, flagFormat)
  1271  	return cmd
  1272  }
  1273  
  1274  func migrateStatusRun(cmd *cobra.Command, _ []string, flags migrateStatusFlags) error {
  1275  	dir, err := dir(flags.dirURL, false)
  1276  	if err != nil {
  1277  		return err
  1278  	}
  1279  	client, err := sqlclient.Open(cmd.Context(), flags.url)
  1280  	if err != nil {
  1281  		return err
  1282  	}
  1283  	defer client.Close()
  1284  	if err := checkRevisionSchemaClarity(cmd, client, flags.revisionSchema); err != nil {
  1285  		return err
  1286  	}
  1287  	report, err := (&cmdmigrate.StatusReporter{
  1288  		Client: client,
  1289  		Dir:    dir,
  1290  		Schema: revisionSchemaName(client, flags.revisionSchema),
  1291  	}).Report(cmd.Context())
  1292  	if err != nil {
  1293  		return err
  1294  	}
  1295  	format := cmdlog.MigrateStatusTemplate
  1296  	if f := flags.logFormat; f != "" {
  1297  		if format, err = template.New("format").Funcs(cmdlog.StatusTemplateFuncs).Parse(f); err != nil {
  1298  			return fmt.Errorf("parse format: %w", err)
  1299  		}
  1300  	}
  1301  	return format.Execute(cmd.OutOrStdout(), report)
  1302  }
  1303  
  1304  type migrateValidateFlags struct {
  1305  	devURL            string
  1306  	dirURL, dirFormat string
  1307  }
  1308  
  1309  // migrateValidateCmd represents the 'atlas migrate validate' subcommand.
  1310  func migrateValidateCmd() *cobra.Command {
  1311  	var (
  1312  		flags migrateValidateFlags
  1313  		cmd   = &cobra.Command{
  1314  			Use:   "validate [flags]",
  1315  			Short: "Validates the migration directories checksum and SQL statements.",
  1316  			Long: `'atlas migrate validate' computes the integrity hash sum of the migration directory and compares it to the
  1317  atlas.sum file. If there is a mismatch it will be reported. If the --dev-url flag is given, the migration
  1318  files are executed on the connected database in order to validate SQL semantics.`,
  1319  			Example: `  atlas migrate validate
  1320    atlas migrate validate --dir file:///path/to/migration/directory
  1321    atlas migrate validate --dir file:///path/to/migration/directory --dev-url mysql://user:pass@localhost:3306/dev
  1322    atlas migrate validate --env dev`,
  1323  			PreRunE: func(cmd *cobra.Command, _ []string) error {
  1324  				if err := migrateFlagsFromConfig(cmd); err != nil {
  1325  					return err
  1326  				}
  1327  				if err := dirFormatBC(flags.dirFormat, &flags.dirURL); err != nil {
  1328  					return err
  1329  				}
  1330  				return checkDir(cmd, flags.dirURL, false)
  1331  			},
  1332  			RunE: func(cmd *cobra.Command, args []string) error {
  1333  				return migrateValidateRun(cmd, args, flags)
  1334  			},
  1335  		}
  1336  	)
  1337  	cmd.Flags().SortFlags = false
  1338  	addFlagDevURL(cmd.Flags(), &flags.devURL)
  1339  	addFlagDirURL(cmd.Flags(), &flags.dirURL)
  1340  	addFlagDirFormat(cmd.Flags(), &flags.dirFormat)
  1341  	return cmd
  1342  }
  1343  
  1344  func migrateValidateRun(cmd *cobra.Command, _ []string, flags migrateValidateFlags) error {
  1345  	// Validating the integrity is done by the PersistentPreRun already.
  1346  	if flags.devURL == "" {
  1347  		// If there is no --dev-url given do not attempt to replay the migration directory.
  1348  		return nil
  1349  	}
  1350  	// Open a client for the dev-db.
  1351  	dev, err := sqlclient.Open(cmd.Context(), flags.devURL)
  1352  	if err != nil {
  1353  		return err
  1354  	}
  1355  	defer dev.Close()
  1356  	// Currently, only our own migration file format is supported.
  1357  	dir, err := dir(flags.dirURL, false)
  1358  	if err != nil {
  1359  		return err
  1360  	}
  1361  	ex, err := migrate.NewExecutor(dev.Driver, dir, migrate.NopRevisionReadWriter{})
  1362  	if err != nil {
  1363  		return err
  1364  	}
  1365  	if _, err := ex.Replay(cmd.Context(), func() migrate.StateReader {
  1366  		if dev.URL.Schema != "" {
  1367  			return migrate.SchemaConn(dev, "", nil)
  1368  		}
  1369  		return migrate.RealmConn(dev, nil)
  1370  	}()); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) {
  1371  		return fmt.Errorf("replaying the migration directory: %w", err)
  1372  	}
  1373  	return nil
  1374  }
  1375  
// applyLockValue is the lock name passed to schema.Locker.Lock (for drivers that
// support locking), e.g. by 'migrate set' while it modifies the revision table.
const applyLockValue = "atlas_migrate_execute"
  1377  
  1378  func checkRevisionSchemaClarity(cmd *cobra.Command, c *sqlclient.Client, revisionSchemaFlag string) error {
  1379  	// The "old" default  behavior for the revision schema location was to store the revision table in its own schema.
  1380  	// Now, the table is saved in the connected schema, if any. To keep the backwards compatability, we now require
  1381  	// for schema bound connections to have the schema-revision flag present if there is no revision table in the schema
  1382  	// but the old default schema does have one.
  1383  	if c.URL.Schema != "" && revisionSchemaFlag == "" {
  1384  		// If the schema does not contain a revision table, but we can find a table in the previous default schema,
  1385  		// abort and tell the user to specify the intention.
  1386  		opts := &schema.InspectOptions{Tables: []string{revision.Table}}
  1387  		s, err := c.InspectSchema(cmd.Context(), "", opts)
  1388  		var ok bool
  1389  		switch {
  1390  		case schema.IsNotExistError(err):
  1391  			// If the schema does not exist, the table does not as well.
  1392  		case err != nil:
  1393  			return err
  1394  		default:
  1395  			// Connected schema does exist, check if the table does.
  1396  			_, ok = s.Table(revision.Table)
  1397  		}
  1398  		if !ok { // Either schema or table does not exist.
  1399  			// Check for the old default schema. If it does not exist, we have no problem.
  1400  			s, err := c.InspectSchema(cmd.Context(), defaultRevisionSchema, opts)
  1401  			switch {
  1402  			case schema.IsNotExistError(err):
  1403  				// Schema does not exist, we can proceed.
  1404  			case err != nil:
  1405  				return err
  1406  			default:
  1407  				if _, ok := s.Table(revision.Table); ok {
  1408  					fmt.Fprintf(cmd.OutOrStderr(),
  1409  						`We couldn't find a revision table in the connected schema but found one in 
  1410  the schema 'atlas_schema_revisions' and cannot determine the desired behavior.
  1411  
  1412  As a safety guard, we require you to specify whether to use the existing
  1413  table in 'atlas_schema_revisions' or create a new one in the connected schema
  1414  by providing the '--revisions-schema' flag or deleting the 'atlas_schema_revisions'
  1415  schema if it is unused.
  1416  
  1417  `)
  1418  					cmd.SilenceUsage = true
  1419  					cmd.SilenceErrors = true
  1420  					return errors.New("ambiguous revision table")
  1421  				}
  1422  			}
  1423  		}
  1424  	}
  1425  	return nil
  1426  }
  1427  
  1428  func entRevisions(ctx context.Context, c *sqlclient.Client, flag string) (*cmdmigrate.EntRevisions, error) {
  1429  	return cmdmigrate.NewEntRevisions(ctx, c, cmdmigrate.WithSchema(revisionSchemaName(c, flag)))
  1430  }
  1431  
// defaultRevisionSchema is the schema that stores the revisions table when no
// revision-schema flag value is given and the connection URL names no schema
// (see revisionSchemaName). It is also the "old" default location checked by
// checkRevisionSchemaClarity.
const defaultRevisionSchema = "atlas_schema_revisions"
  1434  
  1435  func revisionSchemaName(c *sqlclient.Client, flag string) string {
  1436  	switch {
  1437  	case flag != "":
  1438  		return flag
  1439  	case c.URL.Schema != "":
  1440  		return c.URL.Schema
  1441  	default:
  1442  		return defaultRevisionSchema
  1443  	}
  1444  }
  1445  
// Transaction modes controlling how migration files are wrapped in
// transactions (see tx.driverFor).
const (
	txModeNone = "none" // statements run without a transaction
	txModeAll  = "all"  // all files share a single transaction
	txModeFile = "file" // each file runs in its own transaction
)
  1451  
  1452  // tx handles wrapping migration execution in transactions.
  1453  type tx struct {
  1454  	dryRun       bool
  1455  	mode, schema string
  1456  	c            *sqlclient.Client
  1457  	rrw          migrate.RevisionReadWriter
  1458  	// current transaction context.
  1459  	tx    *sqlclient.TxClient
  1460  	txrrw migrate.RevisionReadWriter
  1461  }
  1462  
  1463  // driverFor returns the migrate.Driver to use to execute migration statements.
  1464  func (tx *tx) driverFor(ctx context.Context, f migrate.File) (migrate.Driver, migrate.RevisionReadWriter, error) {
  1465  	if tx.dryRun {
  1466  		// If the --dry-run flag is given we don't want to execute any statements on the database.
  1467  		return &dryRunDriver{tx.c.Driver}, &dryRunRevisions{tx.rrw}, nil
  1468  	}
  1469  	mode, err := tx.modeFor(f)
  1470  	if err != nil {
  1471  		return nil, nil, err
  1472  	}
  1473  	switch mode {
  1474  	case txModeNone:
  1475  		return tx.c.Driver, tx.rrw, nil
  1476  	case txModeFile:
  1477  		// In file-mode, this function is called each time a new file is executed. Open a transaction.
  1478  		if tx.tx != nil {
  1479  			return nil, nil, errors.New("unexpected active transaction")
  1480  		}
  1481  		var err error
  1482  		tx.tx, err = tx.c.Tx(ctx, nil)
  1483  		if err != nil {
  1484  			return nil, nil, err
  1485  		}
  1486  		if tx.txrrw, err = entRevisions(ctx, tx.tx.Client, tx.schema); err != nil {
  1487  			return nil, nil, err
  1488  		}
  1489  		return tx.tx.Driver, tx.txrrw, nil
  1490  	case txModeAll:
  1491  		// In file-mode, this function is called each time a new file is executed. Since we wrap all files into one
  1492  		// huge transaction, if there already is an opened one, use that.
  1493  		if tx.tx == nil {
  1494  			var err error
  1495  			tx.tx, err = tx.c.Tx(ctx, nil)
  1496  			if err != nil {
  1497  				return nil, nil, err
  1498  			}
  1499  			if tx.txrrw, err = entRevisions(ctx, tx.tx.Client, tx.schema); err != nil {
  1500  				return nil, nil, err
  1501  			}
  1502  		}
  1503  		return tx.tx.Driver, tx.txrrw, nil
  1504  	default:
  1505  		return nil, nil, fmt.Errorf("unknown tx-mode %q", mode)
  1506  	}
  1507  }
  1508  
  1509  // mayRollback may roll back a transaction depending on the given transaction mode.
  1510  func (tx *tx) mayRollback(err error) error {
  1511  	if tx.tx != nil && err != nil {
  1512  		if err2 := tx.tx.Rollback(); err2 != nil {
  1513  			err = fmt.Errorf("%v: %w", err2, err)
  1514  		}
  1515  	}
  1516  	return err
  1517  }
  1518  
  1519  // mayCommit may commit a transaction depending on the given transaction mode.
  1520  func (tx *tx) mayCommit() error {
  1521  	// Only commit if each file is wrapped in a transaction.
  1522  	if tx.tx != nil && !tx.dryRun && tx.mode == txModeFile {
  1523  		return tx.commit()
  1524  	}
  1525  	return nil
  1526  }
  1527  
  1528  // commit the transaction, if one is active.
  1529  func (tx *tx) commit() error {
  1530  	if tx.tx == nil {
  1531  		return nil
  1532  	}
  1533  	defer func() { tx.tx, tx.txrrw = nil, nil }()
  1534  	return tx.tx.Commit()
  1535  }
  1536  
  1537  func (tx *tx) modeFor(f migrate.File) (string, error) {
  1538  	l, ok := f.(*migrate.LocalFile)
  1539  	if !ok {
  1540  		return tx.mode, nil
  1541  	}
  1542  	switch ds := l.Directive("txmode"); {
  1543  	case len(ds) > 1:
  1544  		return "", fmt.Errorf("multiple txmode values found in file %q: %q", f.Name(), ds)
  1545  	case len(ds) == 0 || ds[0] == tx.mode:
  1546  		return tx.mode, nil
  1547  	case ds[0] == txModeAll:
  1548  		return "", fmt.Errorf("txmode %q is not allowed in file directive %q", txModeAll, f.Name())
  1549  	case ds[0] == txModeNone, ds[0] == txModeFile:
  1550  		if tx.mode == txModeAll {
  1551  			return "", fmt.Errorf("cannot set txmode directive to %q in %q when txmode %q is set globally", ds[0], f.Name(), txModeAll)
  1552  		}
  1553  		return ds[0], nil
  1554  	default:
  1555  		return "", fmt.Errorf("unknown txmode %q found in file directive %q", ds[0], f.Name())
  1556  	}
  1557  }
  1558  
  1559  func operatorVersion() string {
  1560  	v, _ := parseV(version)
  1561  	return "Atlas CLI " + v
  1562  }
  1563  
  1564  // dir parses u and calls dirURL.
  1565  func dir(u string, create bool) (migrate.Dir, error) {
  1566  	parsed, err := url.Parse(u)
  1567  	if err != nil {
  1568  		return nil, err
  1569  	}
  1570  	return dirURL(parsed, create)
  1571  }
  1572  
  1573  // dirURL returns a migrate.Dir to use as migration directory. For now only local directories are supported.
  1574  func dirURL(u *url.URL, create bool) (migrate.Dir, error) {
  1575  	path := filepath.Join(u.Host, u.Path)
  1576  	switch u.Scheme {
  1577  	case "mem":
  1578  		return migrate.OpenMemDir(path), nil
  1579  	case "file":
  1580  		if path == "" {
  1581  			path = "migrations"
  1582  		}
  1583  	default:
  1584  		return nil, fmt.Errorf("unsupported driver %q", u.Scheme)
  1585  	}
  1586  	fn := func() (migrate.Dir, error) { return migrate.NewLocalDir(path) }
  1587  	switch f := u.Query().Get("format"); f {
  1588  	case "", formatAtlas:
  1589  		// this is the default
  1590  	case formatGolangMigrate:
  1591  		fn = func() (migrate.Dir, error) { return sqltool.NewGolangMigrateDir(path) }
  1592  	case formatGoose:
  1593  		fn = func() (migrate.Dir, error) { return sqltool.NewGooseDir(path) }
  1594  	case formatFlyway:
  1595  		fn = func() (migrate.Dir, error) { return sqltool.NewFlywayDir(path) }
  1596  	case formatLiquibase:
  1597  		fn = func() (migrate.Dir, error) { return sqltool.NewLiquibaseDir(path) }
  1598  	case formatDBMate:
  1599  		fn = func() (migrate.Dir, error) { return sqltool.NewDBMateDir(path) }
  1600  	default:
  1601  		return nil, fmt.Errorf("unknown dir format %q", f)
  1602  	}
  1603  	d, err := fn()
  1604  	if create && errors.Is(err, fs.ErrNotExist) {
  1605  		if err := os.MkdirAll(path, 0755); err != nil {
  1606  			return nil, err
  1607  		}
  1608  		d, err = fn()
  1609  		if err != nil {
  1610  			return nil, err
  1611  		}
  1612  	}
  1613  	return d, err
  1614  }
  1615  
  1616  // dirFormatBC ensures the soon-to-be deprecated --dir-format flag gets set on all migration directory URLs.
  1617  func dirFormatBC(flag string, urls ...*string) error {
  1618  	for _, s := range urls {
  1619  		u, err := url.Parse(*s)
  1620  		if err != nil {
  1621  			return err
  1622  		}
  1623  		if !u.Query().Has("format") && flag != "" {
  1624  			q := u.Query()
  1625  			q.Set("format", flag)
  1626  			u.RawQuery = q.Encode()
  1627  			*s = u.String()
  1628  		}
  1629  	}
  1630  	return nil
  1631  }
  1632  
  1633  func checkDir(cmd *cobra.Command, url string, create bool) error {
  1634  	d, err := dir(url, create)
  1635  	if err != nil {
  1636  		return err
  1637  	}
  1638  	if err = migrate.Validate(d); err != nil {
  1639  		printChecksumError(cmd)
  1640  		return err
  1641  	}
  1642  	return nil
  1643  }
  1644  
  1645  func printChecksumError(cmd *cobra.Command) {
  1646  	fmt.Fprintf(cmd.OutOrStderr(), `You have a checksum error in your migration directory.
  1647  This happens if you manually create or edit a migration file.
  1648  Please check your migration files and run
  1649  
  1650  'atlas migrate hash'
  1651  
  1652  to re-hash the contents and resolve the error
  1653  
  1654  `)
  1655  	cmd.SilenceUsage = true
  1656  }
  1657  
  1658  // selectScheme validates the scheme of the provided to urls and returns the selected
  1659  // url scheme. Currently, all URLs must be of the same scheme, and only multiple
  1660  // "file://" URLs are allowed.
  1661  func selectScheme(urls []string) (string, error) {
  1662  	var scheme string
  1663  	if len(urls) == 0 {
  1664  		return "", errors.New("at least one url is required")
  1665  	}
  1666  	for _, u := range urls {
  1667  		parts := strings.SplitN(u, "://", 2)
  1668  		switch current := parts[0]; {
  1669  		case len(parts) == 1:
  1670  			ex := filepath.Ext(u)
  1671  			switch f, err := os.Stat(u); {
  1672  			case err != nil:
  1673  			case f.IsDir(), ex == extSQL, ex == extHCL:
  1674  				return "", fmt.Errorf("missing scheme. Did you mean file://%s?", u)
  1675  			}
  1676  			return "", errors.New("missing scheme. See: https://atlasgo.io/url")
  1677  		case scheme == "":
  1678  			scheme = current
  1679  		case scheme != current:
  1680  			return "", fmt.Errorf("got mixed --to url schemes: %q and %q, the desired state must be provided from a single kind of source", scheme, current)
  1681  		case current != "file":
  1682  			return "", fmt.Errorf("got multiple --to urls of scheme %q, only multiple 'file://' urls are supported", current)
  1683  		}
  1684  	}
  1685  	return scheme, nil
  1686  }
  1687  
  1688  // parseHCLPaths parses the HCL files in the given paths. If a path represents a directory,
  1689  // its direct descendants will be considered, skipping any subdirectories. If a project file
  1690  // is present in the input paths, an error is returned.
  1691  func parseHCLPaths(paths ...string) (*hclparse.Parser, error) {
  1692  	p := hclparse.NewParser()
  1693  	for _, path := range paths {
  1694  		switch stat, err := os.Stat(path); {
  1695  		case err != nil:
  1696  			return nil, err
  1697  		case stat.IsDir():
  1698  			dir, err := os.ReadDir(path)
  1699  			if err != nil {
  1700  				return nil, err
  1701  			}
  1702  			for _, f := range dir {
  1703  				// Skip nested dirs.
  1704  				if f.IsDir() {
  1705  					continue
  1706  				}
  1707  				if err := mayParse(p, filepath.Join(path, f.Name())); err != nil {
  1708  					return nil, err
  1709  				}
  1710  			}
  1711  		default:
  1712  			if err := mayParse(p, path); err != nil {
  1713  				return nil, err
  1714  			}
  1715  		}
  1716  	}
  1717  	if len(p.Files()) == 0 {
  1718  		return nil, fmt.Errorf("no schema files found in: %s", paths)
  1719  	}
  1720  	return p, nil
  1721  }
  1722  
  1723  // mayParse will parse the file in path if it is an HCL file. If the file is an Atlas
  1724  // project file an error is returned.
  1725  func mayParse(p *hclparse.Parser, path string) error {
  1726  	if n := filepath.Base(path); filepath.Ext(n) != extHCL {
  1727  		return nil
  1728  	}
  1729  	switch f, diag := p.ParseHCLFile(path); {
  1730  	case diag.HasErrors():
  1731  		return diag
  1732  	case isProjectFile(f):
  1733  		return fmt.Errorf("cannot parse project file %q as a schema file", path)
  1734  	default:
  1735  		return nil
  1736  	}
  1737  }
  1738  
  1739  func isProjectFile(f *hcl.File) bool {
  1740  	for _, blk := range f.Body.(*hclsyntax.Body).Blocks {
  1741  		if blk.Type == "env" {
  1742  			return true
  1743  		}
  1744  	}
  1745  	return false
  1746  }
  1747  
  1748  const (
  1749  	formatAtlas         = "atlas"
  1750  	formatGolangMigrate = "golang-migrate"
  1751  	formatGoose         = "goose"
  1752  	formatFlyway        = "flyway"
  1753  	formatLiquibase     = "liquibase"
  1754  	formatDBMate        = "dbmate"
  1755  )
  1756  
  1757  func formatter(u *url.URL) (migrate.Formatter, error) {
  1758  	switch f := u.Query().Get("format"); f {
  1759  	case formatAtlas:
  1760  		return migrate.DefaultFormatter, nil
  1761  	case formatGolangMigrate:
  1762  		return sqltool.GolangMigrateFormatter, nil
  1763  	case formatGoose:
  1764  		return sqltool.GooseFormatter, nil
  1765  	case formatFlyway:
  1766  		return sqltool.FlywayFormatter, nil
  1767  	case formatLiquibase:
  1768  		return sqltool.LiquibaseFormatter, nil
  1769  	case formatDBMate:
  1770  		return sqltool.DBMateFormatter, nil
  1771  	default:
  1772  		return nil, fmt.Errorf("unknown format %q", f)
  1773  	}
  1774  }
  1775  
  1776  func migrateFlagsFromConfig(cmd *cobra.Command) error {
  1777  	env, err := selectEnv(cmd)
  1778  	if err != nil {
  1779  		return err
  1780  	}
  1781  	return setMigrateEnvFlags(cmd, env)
  1782  }
  1783  
  1784  func setMigrateEnvFlags(cmd *cobra.Command, env *Env) error {
  1785  	if err := inputValuesFromEnv(cmd, env); err != nil {
  1786  		return err
  1787  	}
  1788  	if err := maySetFlag(cmd, flagURL, env.URL); err != nil {
  1789  		return err
  1790  	}
  1791  	if err := maySetFlag(cmd, flagDevURL, env.DevURL); err != nil {
  1792  		return err
  1793  	}
  1794  	if err := maySetFlag(cmd, flagDirURL, env.Migration.Dir); err != nil {
  1795  		return err
  1796  	}
  1797  	if err := maySetFlag(cmd, flagDirFormat, env.Migration.Format); err != nil {
  1798  		return err
  1799  	}
  1800  	if err := maySetFlag(cmd, flagBaseline, env.Migration.Baseline); err != nil {
  1801  		return err
  1802  	}
  1803  	if err := maySetFlag(cmd, flagRevisionSchema, env.Migration.RevisionsSchema); err != nil {
  1804  		return err
  1805  	}
  1806  	switch cmd.Name() {
  1807  	case "apply":
  1808  		if err := maySetFlag(cmd, flagFormat, env.Format.Migrate.Apply); err != nil {
  1809  			return err
  1810  		}
  1811  		if err := maySetFlag(cmd, flagLockTimeout, env.Migration.LockTimeout); err != nil {
  1812  			return err
  1813  		}
  1814  	case "diff":
  1815  		if err := maySetFlag(cmd, flagLockTimeout, env.Migration.LockTimeout); err != nil {
  1816  			return err
  1817  		}
  1818  		if err := maySetFlag(cmd, flagFormat, env.Format.Migrate.Diff); err != nil {
  1819  			return err
  1820  		}
  1821  	case "lint":
  1822  		if err := maySetFlag(cmd, flagFormat, env.Format.Migrate.Lint); err != nil {
  1823  			return err
  1824  		}
  1825  		if err := maySetFlag(cmd, flagFormat, env.Lint.Format); err != nil {
  1826  			return err
  1827  		}
  1828  		if err := maySetFlag(cmd, flagLatest, strconv.Itoa(env.Lint.Latest)); err != nil {
  1829  			return err
  1830  		}
  1831  		if err := maySetFlag(cmd, flagGitDir, env.Lint.Git.Dir); err != nil {
  1832  			return err
  1833  		}
  1834  		if err := maySetFlag(cmd, flagGitBase, env.Lint.Git.Base); err != nil {
  1835  			return err
  1836  		}
  1837  	case "status":
  1838  		if err := maySetFlag(cmd, flagFormat, env.Format.Migrate.Status); err != nil {
  1839  			return err
  1840  		}
  1841  	}
  1842  	// Transform "src" to a URL.
  1843  	srcs, err := env.Sources()
  1844  	if err != nil {
  1845  		return err
  1846  	}
  1847  	for i, s := range srcs {
  1848  		if isURL(s) {
  1849  			continue
  1850  		}
  1851  		if s, err = filepath.Abs(s); err != nil {
  1852  			return fmt.Errorf("finding abs path to source: %q: %w", s, err)
  1853  		}
  1854  		srcs[i] = "file://" + s
  1855  	}
  1856  	if err := maySetFlag(cmd, flagTo, strings.Join(srcs, ",")); err != nil {
  1857  		return err
  1858  	}
  1859  	if err := maySetFlag(cmd, flagSchema, strings.Join(env.Schemas, ",")); err != nil {
  1860  		return err
  1861  	}
  1862  	return nil
  1863  }
  1864  
  1865  // isURL returns true if the given string
  1866  // is an Atlas URL with a scheme.
  1867  func isURL(s string) bool {
  1868  	u, err := url.Parse(s)
  1869  	return err == nil && u.Scheme != ""
  1870  }
  1871  
  1872  // cmdEnvsRun executes a given command on each of the configured environment.
  1873  func cmdEnvsRun(
  1874  	envs []*Env,
  1875  	setFlags func(*cobra.Command, *Env) error,
  1876  	cmd *cobra.Command,
  1877  	runCmd func(*Env) error,
  1878  ) error {
  1879  	var (
  1880  		w     bytes.Buffer
  1881  		out   = cmd.OutOrStdout()
  1882  		reset = resetFromEnv(cmd)
  1883  	)
  1884  	cmd.SetOut(io.MultiWriter(out, &w))
  1885  	defer cmd.SetOut(out)
  1886  	for i, e := range envs {
  1887  		if err := setFlags(cmd, e); err != nil {
  1888  			return err
  1889  		}
  1890  		if err := runCmd(e); err != nil {
  1891  			return err
  1892  		}
  1893  		b := bytes.TrimLeft(w.Bytes(), " \t\r")
  1894  		// In case a custom logging was configured, ensure there is
  1895  		// a newline separator between the different environments.
  1896  		if cmd.Flags().Changed(flagFormat) && bytes.LastIndexByte(b, '\n') != len(b)-1 && i != len(envs)-1 {
  1897  			cmd.Println()
  1898  		}
  1899  		reset()
  1900  		w.Reset()
  1901  	}
  1902  	return nil
  1903  }
  1904  
  1905  type editDir struct{ migrate.Dir }
  1906  
  1907  // WriteFile implements the migrate.Dir.WriteFile method.
  1908  func (d *editDir) WriteFile(name string, b []byte) (err error) {
  1909  	if name != migrate.HashFileName {
  1910  		if b, err = edit(name, b); err != nil {
  1911  			return err
  1912  		}
  1913  	}
  1914  	return d.Dir.WriteFile(name, b)
  1915  }
  1916  
  1917  // edit allows editing the file content using editor.
  1918  func edit(name string, src []byte) ([]byte, error) {
  1919  	path := filepath.Join(os.TempDir(), name)
  1920  	if err := os.WriteFile(path, src, 0644); err != nil {
  1921  		return nil, fmt.Errorf("write source content to temp file: %w", err)
  1922  	}
  1923  	defer os.Remove(path)
  1924  	editor := "vi"
  1925  	if e := os.Getenv("EDITOR"); e != "" {
  1926  		editor = e
  1927  	}
  1928  	cmd := exec.Command("sh", "-c", editor+" "+path)
  1929  	cmd.Stdin = os.Stdin
  1930  	cmd.Stdout = os.Stdout
  1931  	cmd.Stderr = os.Stderr
  1932  	if err := cmd.Run(); err != nil {
  1933  		return nil, fmt.Errorf("exec edit: %w", err)
  1934  	}
  1935  	b, err := os.ReadFile(path)
  1936  	if err != nil {
  1937  		return nil, fmt.Errorf("read edited temp file: %w", err)
  1938  	}
  1939  	return b, nil
  1940  }
  1941  
  1942  type (
  1943  	// dryRunDriver wraps a migrate.Driver without executing any SQL statements.
  1944  	dryRunDriver struct{ migrate.Driver }
  1945  
  1946  	// dryRunRevisions wraps a migrate.RevisionReadWriter without executing any SQL statements.
  1947  	dryRunRevisions struct{ migrate.RevisionReadWriter }
  1948  )
  1949  
  1950  // QueryContext overrides the wrapped schema.ExecQuerier to not execute any SQL.
  1951  func (dryRunDriver) QueryContext(context.Context, string, ...any) (*sql.Rows, error) {
  1952  	return nil, nil
  1953  }
  1954  
  1955  // ExecContext overrides the wrapped schema.ExecQuerier to not execute any SQL.
  1956  func (dryRunDriver) ExecContext(context.Context, string, ...any) (sql.Result, error) {
  1957  	return nil, nil
  1958  }
  1959  
  1960  // Lock implements the schema.Locker interface.
  1961  func (dryRunDriver) Lock(context.Context, string, time.Duration) (schema.UnlockFunc, error) {
  1962  	// We dry-run, we don't execute anything. Locking is not required.
  1963  	return func() error { return nil }, nil
  1964  }
  1965  
  1966  // CheckClean implements the migrate.CleanChecker interface.
  1967  func (dryRunDriver) CheckClean(context.Context, *migrate.TableIdent) error {
  1968  	return nil
  1969  }
  1970  
  1971  // Snapshot implements the migrate.Snapshoter interface.
  1972  func (dryRunDriver) Snapshot(context.Context) (migrate.RestoreFunc, error) {
  1973  	// We dry-run, we don't execute anything. Snapshotting not required.
  1974  	return func(context.Context) error { return nil }, nil
  1975  }
  1976  
  1977  // WriteRevision overrides the wrapped migrate.RevisionReadWriter to not saved any changes to revisions.
  1978  func (dryRunRevisions) WriteRevision(context.Context, *migrate.Revision) error {
  1979  	return nil
  1980  }