// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

// Package cmdapi holds the atlas commands used to build an atlas distribution.
package cmdapi

import (
	"context"
	"encoding/csv"
	"errors"
	"fmt"
	"io"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/iasthc/atlas/cmd/atlas/internal/cmdext"
	"github.com/iasthc/atlas/sql/migrate"
	"github.com/iasthc/atlas/sql/schema"
	"github.com/iasthc/atlas/sql/sqlclient"

	"github.com/spf13/cobra"
	"github.com/spf13/pflag"
	"github.com/zclconf/go-cty/cty"
	"golang.org/x/mod/semver"
)

var (
	// Root represents the root command when called without any subcommands.
	Root = &cobra.Command{
		Use:          "atlas",
		Short:        "A database toolkit.",
		SilenceUsage: true,
	}

	// GlobalFlags contains flags common to many Atlas sub-commands.
	GlobalFlags struct {
		// ConfigURL defines the path to the Atlas project/config file (--config flag).
		ConfigURL string
		// SelectedEnv contains the environment selected from the active project via the --env flag.
		SelectedEnv string
		// Vars contains the input variables passed from the CLI to Atlas DDL or project files.
		Vars Vars
	}

	// version holds the Atlas version. When built with cloud packages, it should be set by the build flag
	// "-X 'github.com/iasthc/atlas/cmd/atlas/internal/cmdapi.version=${version}'"
	version string

	// versionCmd represents the subcommand 'atlas version'. It prints the
	// user-facing version and the matching release-notes URL (see parseV).
	versionCmd = &cobra.Command{
		Use:   "version",
		Short: "Prints this Atlas CLI version information.",
		Run: func(cmd *cobra.Command, args []string) {
			v, u := parseV(version)
			cmd.Printf("atlas version %s\n%s\n", v, u)
		},
	}

	// license holds the Atlas license text. When built with cloud packages, it should be set by the build flag
	// "-X 'github.com/iasthc/atlas/cmd/atlas/internal/cmdapi.license=${license}'"
	license = `LICENSE
Atlas is licensed under Apache 2.0 as found in https://github.com/ariga/atlas/blob/master/LICENSE.`
	// licenseCmd represents the subcommand 'atlas license', which prints license.
	licenseCmd = &cobra.Command{
		Use:   "license",
		Short: "Display license information",
		Run: func(cmd *cobra.Command, _ []string) {
			cmd.Println(license)
		},
	}
)

// init registers the version and license subcommands on Root.
func init() {
	Root.AddCommand(versionCmd)
	Root.AddCommand(licenseCmd)
	// Register a global function to clean up the global
	// flags regardless if the command passed or failed.
	cobra.OnFinalize(func() {
		GlobalFlags.ConfigURL = ""
		GlobalFlags.Vars = nil
		GlobalFlags.SelectedEnv = ""
	})
}

// inputValuesFromEnv populates GlobalFlags.Vars from the active environment by
// setting the "var" flag (when the command defines one) to the environment's
// values. If we are working inside a project, the "var" flag is not propagated
// to the schema definition. Instead, it is used to evaluate the project file,
// which can pass input values via the "values" block to the schema.
92 func inputValuesFromEnv(cmd *cobra.Command, env *Env) error { 93 if fl := cmd.Flag(flagVar); fl == nil { 94 return nil 95 } 96 values, err := env.asMap() 97 if err != nil { 98 return err 99 } 100 if len(values) == 0 { 101 return nil 102 } 103 pairs := make([]string, 0, len(values)) 104 for k, v := range values { 105 pairs = append(pairs, fmt.Sprintf("%s=%s", k, v)) 106 } 107 vars := strings.Join(pairs, ",") 108 if err := cmd.Flags().Set(flagVar, vars); err != nil { 109 return fmt.Errorf("set flag %q: %w", flagVar, err) 110 } 111 return nil 112 } 113 114 // parseV returns a user facing version and release notes url 115 func parseV(version string) (string, string) { 116 u := "https://github.com/ariga/atlas/releases/latest" 117 if ok := semver.IsValid(version); !ok { 118 return "- development", u 119 } 120 s := strings.Split(version, "-") 121 if len(s) != 0 && s[len(s)-1] != "canary" { 122 u = fmt.Sprintf("https://github.com/ariga/atlas/releases/tag/%s", version) 123 } 124 return version, u 125 } 126 127 // Version returns the current Atlas binary version. 128 func Version() string { 129 return version 130 } 131 132 // Vars implements pflag.Value. 133 type Vars map[string]cty.Value 134 135 // String implements pflag.Value.String. 136 func (v Vars) String() string { 137 var b strings.Builder 138 for k := range v { 139 if b.Len() > 0 { 140 b.WriteString(", ") 141 } 142 b.WriteString(k) 143 b.WriteString(":") 144 b.WriteString(v[k].GoString()) 145 } 146 return "[" + b.String() + "]" 147 } 148 149 // Copy returns a copy of the current variables. 150 func (v Vars) Copy() Vars { 151 vc := make(Vars) 152 for k := range v { 153 vc[k] = v[k] 154 } 155 return vc 156 } 157 158 // Replace overrides the variables. 159 func (v *Vars) Replace(vc Vars) { 160 *v = vc 161 } 162 163 // Set implements pflag.Value.Set. 
164 func (v *Vars) Set(s string) error { 165 if *v == nil { 166 *v = make(Vars) 167 } 168 kvs, err := csv.NewReader(strings.NewReader(s)).Read() 169 if err != nil { 170 return err 171 } 172 for i := range kvs { 173 kv := strings.SplitN(kvs[i], "=", 2) 174 if len(kv) != 2 { 175 return fmt.Errorf("variables must be format as key=value, got: %q", kvs[i]) 176 } 177 v1 := cty.StringVal(kv[1]) 178 switch v0, ok := (*v)[kv[0]]; { 179 case ok && v0.Type().IsListType(): 180 (*v)[kv[0]] = cty.ListVal(append(v0.AsValueSlice(), v1)) 181 case ok: 182 (*v)[kv[0]] = cty.ListVal([]cty.Value{v0, v1}) 183 default: 184 (*v)[kv[0]] = v1 185 } 186 } 187 return nil 188 } 189 190 // Type implements pflag.Value.Type. 191 func (v *Vars) Type() string { 192 return "<name>=<value>" 193 } 194 195 const ( 196 flagAllowDirty = "allow-dirty" 197 flagEdit = "edit" 198 flagAutoApprove = "auto-approve" 199 flagBaseline = "baseline" 200 flagConfig = "config" 201 flagDevURL = "dev-url" 202 flagDirURL = "dir" 203 flagDirFormat = "dir-format" 204 flagDryRun = "dry-run" 205 flagEnv = "env" 206 flagExclude = "exclude" 207 flagFile = "file" 208 flagFrom = "from" 209 flagFromShort = "f" 210 flagFormat = "format" 211 flagGitBase = "git-base" 212 flagGitDir = "git-dir" 213 flagLatest = "latest" 214 flagLockTimeout = "lock-timeout" 215 flagLog = "log" 216 flagRevisionSchema = "revisions-schema" 217 flagSchema = "schema" 218 flagSchemaShort = "s" 219 flagTo = "to" 220 flagTxMode = "tx-mode" 221 flagURL = "url" 222 flagURLShort = "u" 223 flagVar = "var" 224 flagQualifier = "qualifier" 225 ) 226 227 func addGlobalFlags(set *pflag.FlagSet) { 228 set.StringVar(&GlobalFlags.SelectedEnv, flagEnv, "", "set which env from the config file to use") 229 set.Var(&GlobalFlags.Vars, flagVar, "input variables") 230 set.StringVarP(&GlobalFlags.ConfigURL, flagConfig, "c", defaultConfigPath, "select config (project) file using URL format") 231 } 232 233 func addFlagAutoApprove(set *pflag.FlagSet, target *bool) { 234 
set.BoolVar(target, flagAutoApprove, false, "apply changes without prompting for approval") 235 } 236 237 func addFlagDirFormat(set *pflag.FlagSet, target *string) { 238 set.StringVar(target, flagDirFormat, "atlas", "select migration file format") 239 } 240 241 func addFlagLockTimeout(set *pflag.FlagSet, target *time.Duration) { 242 set.DurationVar(target, flagLockTimeout, 10*time.Second, "set how long to wait for the database lock") 243 } 244 245 // addFlagURL adds a URL flag. If given, args[0] override the name, args[1] the shorthand, args[2] the default value. 246 func addFlagDirURL(set *pflag.FlagSet, target *string, args ...string) { 247 name, short, val := flagDirURL, "", "file://migrations" 248 switch len(args) { 249 case 3: 250 val = args[2] 251 fallthrough 252 case 2: 253 short = args[1] 254 fallthrough 255 case 1: 256 name = args[0] 257 } 258 set.StringVarP(target, name, short, val, "select migration directory using URL format") 259 } 260 261 func addFlagDevURL(set *pflag.FlagSet, target *string) { 262 set.StringVar( 263 target, 264 flagDevURL, 265 "", 266 "[driver://username:password@address/dbname?param=value] select a dev database using the URL format", 267 ) 268 } 269 270 func addFlagDryRun(set *pflag.FlagSet, target *bool) { 271 set.BoolVar(target, flagDryRun, false, "print SQL without executing it") 272 } 273 274 func addFlagExclude(set *pflag.FlagSet, target *[]string) { 275 set.StringSliceVar( 276 target, 277 flagExclude, 278 nil, 279 "list of glob patterns used to filter resources from applying", 280 ) 281 } 282 283 func addFlagLog(set *pflag.FlagSet, target *string) { 284 set.StringVar(target, flagLog, "", "Go template to use to format the output") 285 // Use MarkHidden instead of MarkDeprecated to avoid 286 // spam users' system logs with deprecation warnings. 
287 cobra.CheckErr(set.MarkHidden(flagLog)) 288 } 289 290 func addFlagFormat(set *pflag.FlagSet, target *string) { 291 set.StringVar(target, flagFormat, "", "Go template to use to format the output") 292 } 293 294 func addFlagRevisionSchema(set *pflag.FlagSet, target *string) { 295 set.StringVar(target, flagRevisionSchema, "", "name of the schema the revisions table resides in") 296 } 297 298 func addFlagSchemas(set *pflag.FlagSet, target *[]string) { 299 set.StringSliceVarP( 300 target, 301 flagSchema, flagSchemaShort, 302 nil, 303 "set schema names", 304 ) 305 } 306 307 // addFlagURL adds a URL flag. If given, args[0] override the name, args[1] the shorthand. 308 func addFlagURL(set *pflag.FlagSet, target *string, args ...string) { 309 name, short := flagURL, flagURLShort 310 switch len(args) { 311 case 2: 312 short = args[1] 313 fallthrough 314 case 1: 315 name = args[0] 316 } 317 set.StringVarP( 318 target, 319 name, short, 320 "", 321 "[driver://username:password@address/dbname?param=value] select a resource using the URL format", 322 ) 323 } 324 325 func addFlagURLs(set *pflag.FlagSet, target *[]string, args ...string) { 326 name, short := flagURL, flagURLShort 327 switch len(args) { 328 case 2: 329 short = args[1] 330 fallthrough 331 case 1: 332 name = args[0] 333 } 334 set.StringSliceVarP( 335 target, 336 name, short, 337 nil, 338 "[driver://username:password@address/dbname?param=value] select a resource using the URL format", 339 ) 340 } 341 342 func addFlagToURLs(set *pflag.FlagSet, target *[]string) { 343 set.StringSliceVarP(target, flagTo, "", nil, "[driver://username:password@address/dbname?param=value] select a desired state using the URL format") 344 } 345 346 // maySetFlag sets the flag with the provided name to envVal if such a flag exists 347 // on the cmd, it was not set by the user via the command line and if envVal is not 348 // an empty string. 
349 func maySetFlag(cmd *cobra.Command, name, envVal string) error { 350 if f := cmd.Flag(name); f == nil || f.Changed || envVal == "" { 351 return nil 352 } 353 return cmd.Flags().Set(name, envVal) 354 } 355 356 // resetFromEnv traverses the command flags, records what flags 357 // were not set by the user and returns a callback to clear them 358 // after it was set by the current environment. 359 func resetFromEnv(cmd *cobra.Command) func() { 360 mayReset := make(map[string]func() error) 361 cmd.Flags().VisitAll(func(f *pflag.Flag) { 362 if f.Changed { 363 return 364 } 365 vs := f.Value.String() 366 r := func() error { return f.Value.Set(vs) } 367 if v, ok := f.Value.(*Vars); ok { 368 vs := v.Copy() 369 r = func() error { 370 v.Replace(vs) 371 return nil 372 } 373 } else if v, ok := f.Value.(pflag.SliceValue); ok { 374 vs := v.GetSlice() 375 r = func() error { 376 return v.Replace(vs) 377 } 378 } 379 mayReset[f.Name] = r 380 }) 381 return func() { 382 for name, reset := range mayReset { 383 if f := cmd.Flag(name); f != nil && f.Changed { 384 f.Changed = false 385 // Unexpected error, because this flag was set before. 386 cobra.CheckErr(reset()) 387 } 388 } 389 } 390 } 391 392 type ( 393 // stateReadCloser is a migrate.StateReader with an optional io.Closer. 394 stateReadCloser struct { 395 migrate.StateReader 396 io.Closer // optional close function 397 schema string // in case we work on a single schema 398 hcl bool // true if state was read from HCL files since in that case we always compare realms 399 } 400 // stateReaderConfig is given to stateReader. 401 stateReaderConfig struct { 402 urls []string // urls to create a migrate.StateReader from 403 client, dev *sqlclient.Client // database connections, while dev is considered a dev database, client is not 404 schemas []string // schemas to work on 405 exclude []string // exclude flag values 406 vars Vars 407 } 408 ) 409 410 // stateReader returns a migrate.StateReader that reads the state from the given urls. 
411 func stateReader(ctx context.Context, config *stateReaderConfig) (*stateReadCloser, error) { 412 scheme, err := selectScheme(config.urls) 413 if err != nil { 414 return nil, err 415 } 416 parsed := make([]*url.URL, len(config.urls)) 417 for i, u := range config.urls { 418 parsed[i], err = url.Parse(u) 419 if err != nil { 420 return nil, err 421 } 422 } 423 switch scheme { 424 // "file" scheme is valid for both migration directory and HCL paths. 425 case "file": 426 switch ext, err := filesExt(parsed); { 427 case err != nil: 428 return nil, err 429 case ext == extHCL: 430 return hclStateReader(ctx, config, parsed) 431 case ext == extSQL: 432 return sqlStateReader(ctx, config, parsed) 433 default: 434 panic("unreachable") // checked by filesExt. 435 } 436 default: 437 // In case there is an external state-loader registered with this scheme. 438 if l, ok := cmdext.States.Loader(scheme); ok { 439 sr, err := l.LoadState(ctx, &cmdext.LoadStateOptions{URLs: parsed, Dev: config.dev}) 440 if err != nil { 441 return nil, err 442 } 443 rc := &stateReadCloser{StateReader: sr} 444 if config.dev != nil && config.dev.URL.Schema != "" { 445 rc.schema = config.dev.URL.Schema 446 } 447 return rc, nil 448 } 449 // All other schemes are database (or docker) connections. 450 c, err := sqlclient.Open(ctx, config.urls[0]) // call to selectScheme already checks for len > 0 451 if err != nil { 452 return nil, err 453 } 454 var sr migrate.StateReader 455 switch c.URL.Schema { 456 case "": 457 sr = migrate.RealmConn(c.Driver, &schema.InspectRealmOption{ 458 Schemas: config.schemas, 459 Exclude: config.exclude, 460 }) 461 default: 462 sr = migrate.SchemaConn(c.Driver, c.URL.Schema, &schema.InspectOptions{Exclude: config.exclude}) 463 } 464 return &stateReadCloser{ 465 StateReader: sr, 466 Closer: c, 467 schema: c.URL.Schema, 468 }, nil 469 } 470 } 471 472 // hclStateReadr returns a StateReader that reads the state from the given HCL paths urls. 
473 func hclStateReader(ctx context.Context, config *stateReaderConfig, urls []*url.URL) (*stateReadCloser, error) { 474 var client *sqlclient.Client 475 switch { 476 case config.dev != nil: 477 client = config.dev 478 case config.client != nil: 479 client = config.client 480 default: 481 return nil, errors.New("--dev-url cannot be empty") 482 } 483 paths := make([]string, len(urls)) 484 for i, u := range urls { 485 paths[i] = filepath.Join(u.Host, u.Path) 486 } 487 parser, err := parseHCLPaths(paths...) 488 if err != nil { 489 return nil, err 490 } 491 realm := &schema.Realm{} 492 if err := client.Eval(parser, realm, config.vars); err != nil { 493 return nil, err 494 } 495 if len(config.schemas) > 0 { 496 // Validate all schemas in file were selected by user. 497 sm := make(map[string]bool, len(config.schemas)) 498 for _, s := range config.schemas { 499 sm[s] = true 500 } 501 for _, s := range realm.Schemas { 502 if !sm[s.Name] { 503 return nil, fmt.Errorf("schema %q from paths %q is not requested (all schemas in HCL must be requested)", s.Name, paths) 504 } 505 } 506 } 507 // In case the dev connection is bound to a specific schema, we require the 508 // desired schema to contain only one schema. Thus, executing diff will be 509 // done on the content of these two schema and not the whole realm. 
510 if client.URL.Schema != "" && len(realm.Schemas) > 1 { 511 return nil, fmt.Errorf( 512 "cannot use HCL with more than 1 schema when dev-url is limited to schema %q", 513 config.dev.URL.Schema, 514 ) 515 } 516 if norm, ok := client.Driver.(schema.Normalizer); ok && config.dev != nil { // only normalize on a dev database 517 realm, err = norm.NormalizeRealm(ctx, realm) 518 if err != nil { 519 return nil, err 520 } 521 } 522 t := &stateReadCloser{StateReader: migrate.Realm(realm), hcl: true} 523 return t, nil 524 } 525 526 func sqlStateReader(ctx context.Context, config *stateReaderConfig, urls []*url.URL) (*stateReadCloser, error) { 527 if len(urls) != 1 { 528 return nil, fmt.Errorf("the provided SQL state must be either a single schema file or a migration directory, but %d paths were found", len(urls)) 529 } 530 // Replaying a migration directory requires a dev connection. 531 if config.dev == nil { 532 return nil, errors.New("--dev-url cannot be empty") 533 } 534 var ( 535 dir migrate.Dir 536 opts []migrate.ReplayOption 537 path = filepath.Join(urls[0].Host, urls[0].Path) 538 ) 539 switch fi, err := os.Stat(path); { 540 case err != nil: 541 return nil, err 542 // A single schema file. 543 case !fi.IsDir(): 544 b, err := os.ReadFile(path) 545 if err != nil { 546 return nil, err 547 } 548 dir = &validMemDir{} 549 if err := dir.WriteFile(fi.Name(), b); err != nil { 550 return nil, err 551 } 552 // A migration directory. 
553 default: 554 if dir, err = dirURL(urls[0], false); err != nil { 555 return nil, err 556 } 557 if v := urls[0].Query().Get("version"); v != "" { 558 opts = append(opts, migrate.ReplayToVersion(v)) 559 } 560 } 561 ex, err := migrate.NewExecutor(config.dev.Driver, dir, migrate.NopRevisionReadWriter{}) 562 if err != nil { 563 return nil, err 564 } 565 sr, err := ex.Replay(ctx, func() migrate.StateReader { 566 if config.dev.URL.Schema != "" { 567 return migrate.SchemaConn(config.dev, "", nil) 568 } 569 return migrate.RealmConn(config.dev, &schema.InspectRealmOption{ 570 Schemas: config.schemas, 571 Exclude: config.exclude, 572 }) 573 }(), opts...) 574 if err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) { 575 return nil, err 576 } 577 return &stateReadCloser{ 578 StateReader: migrate.Realm(sr), 579 schema: config.dev.URL.Schema, 580 }, nil 581 } 582 583 // Close redirects calls to Close to the enclosed io.Closer. 584 func (sr *stateReadCloser) Close() { 585 if sr.Closer != nil { 586 sr.Closer.Close() 587 } 588 } 589 590 // validMemDir will not throw an error when put into migrate.Validate. 
591 type validMemDir struct{ migrate.MemDir } 592 593 func (d *validMemDir) Validate() error { return nil } 594 595 const ( 596 extHCL = ".hcl" 597 extSQL = ".sql" 598 ) 599 600 func filesExt(urls []*url.URL) (string, error) { 601 var path, ext string 602 set := func(curr string) error { 603 switch e := filepath.Ext(curr); { 604 case e != extHCL && e != extSQL: 605 return fmt.Errorf("unknown schema file: %q", curr) 606 case ext != "" && ext != e: 607 return fmt.Errorf("ambiguous schema: both SQL and HCL files found: %q, %q", path, curr) 608 default: 609 path, ext = curr, e 610 return nil 611 } 612 } 613 for _, u := range urls { 614 path := filepath.Join(u.Host, u.Path) 615 switch fi, err := os.Stat(path); { 616 case err != nil: 617 return "", err 618 case fi.IsDir(): 619 files, err := os.ReadDir(path) 620 if err != nil { 621 return "", err 622 } 623 for _, f := range files { 624 switch filepath.Ext(f.Name()) { 625 // Ignore unknown extensions in case we read directories. 626 case extHCL, extSQL: 627 if err := set(f.Name()); err != nil { 628 return "", err 629 } 630 } 631 } 632 default: 633 if err := set(fi.Name()); err != nil { 634 return "", err 635 } 636 } 637 } 638 switch { 639 case ext != "": 640 case len(urls) == 1 && (urls[0].Host != "" || urls[0].Path != ""): 641 return "", fmt.Errorf( 642 "%q contains neither SQL nor HCL files", 643 filepath.Base(filepath.Join(urls[0].Host, urls[0].Path)), 644 ) 645 default: 646 return "", errors.New("schema contains neither SQL nor HCL files") 647 } 648 return ext, nil 649 }