github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/main.go (about)

     1  // Package main defines a command line interface for the sqlboiler package
     2  package main
     3  
     4  import (
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"strings"
     9  
    10  	"github.com/friendsofgo/errors"
    11  	"github.com/spf13/cobra"
    12  	"github.com/spf13/viper"
    13  	"github.com/volatiletech/sqlboiler/v4/boilingcore"
    14  	"github.com/volatiletech/sqlboiler/v4/drivers"
    15  	"github.com/volatiletech/sqlboiler/v4/importers"
    16  )
    17  
    18  const sqlBoilerVersion = "IOTech" //"4.13.0"
    19  
    20  var (
    21  	flagConfigFile string
    22  	cmdState       *boilingcore.State
    23  	cmdConfig      *boilingcore.Config
    24  )
    25  
    26  func initConfig() {
    27  	if len(flagConfigFile) != 0 {
    28  		viper.SetConfigFile(flagConfigFile)
    29  		if err := viper.ReadInConfig(); err != nil {
    30  			fmt.Println("Can't read config:", err)
    31  			os.Exit(1)
    32  		}
    33  		return
    34  	}
    35  
    36  	var err error
    37  	viper.SetConfigName("sqlboiler")
    38  
    39  	configHome := os.Getenv("XDG_CONFIG_HOME")
    40  	homePath := os.Getenv("HOME")
    41  	wd, err := os.Getwd()
    42  	if err != nil {
    43  		wd = "."
    44  	}
    45  
    46  	configPaths := []string{wd}
    47  	if len(configHome) > 0 {
    48  		configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler"))
    49  	} else {
    50  		configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler"))
    51  	}
    52  
    53  	for _, p := range configPaths {
    54  		viper.AddConfigPath(p)
    55  	}
    56  
    57  	// Ignore errors here, fallback to other validation methods.
    58  	// Users can use environment variables if a config is not found.
    59  	_ = viper.ReadInConfig()
    60  }
    61  
    62  func main() {
    63  	// Too much happens between here and cobra's argument handling, for
    64  	// something so simple just do it immediately.
    65  	for _, arg := range os.Args {
    66  		if arg == "--version" {
    67  			fmt.Println("SQLBoiler v" + sqlBoilerVersion)
    68  			return
    69  		}
    70  	}
    71  
    72  	// Set up the cobra root command
    73  	rootCmd := &cobra.Command{
    74  		Use:   "sqlboiler [flags] <driver>",
    75  		Short: "SQL Boiler generates an ORM tailored to your database schema.",
    76  		Long: "SQL Boiler generates a Go ORM from template files, tailored to your database schema.\n" +
    77  			`Complete documentation is available at http://github.com/volatiletech/sqlboiler`,
    78  		Example:       `sqlboiler psql`,
    79  		PreRunE:       preRun,
    80  		RunE:          run,
    81  		PostRunE:      postRun,
    82  		SilenceErrors: true,
    83  		SilenceUsage:  true,
    84  	}
    85  
    86  	cobra.OnInitialize(initConfig)
    87  
    88  	// Set up the cobra root command flags
    89  	rootCmd.PersistentFlags().StringVarP(&flagConfigFile, "config", "c", "", "Filename of config file to override default lookup")
    90  	rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
    91  	rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
    92  	rootCmd.PersistentFlags().StringSliceP("templates", "", nil, "A templates directory, overrides the embedded template folders in sqlboiler")
    93  	rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml")
    94  	rootCmd.PersistentFlags().StringSliceP("replace", "", nil, "Replace templates by directory: relpath/to_file.tpl:relpath/to_replacement.tpl")
    95  	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error")
    96  	rootCmd.PersistentFlags().BoolP("no-context", "", false, "Disable context.Context usage in the generated code")
    97  	rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files")
    98  	rootCmd.PersistentFlags().BoolP("no-hooks", "", false, "Disable hooks feature for your models")
    99  	rootCmd.PersistentFlags().BoolP("no-rows-affected", "", false, "Disable rows affected in the generated API")
   100  	rootCmd.PersistentFlags().BoolP("no-auto-timestamps", "", false, "Disable automatic timestamps for created_at/updated_at")
   101  	rootCmd.PersistentFlags().BoolP("no-driver-templates", "", false, "Disable parsing of templates defined by the database driver")
   102  	rootCmd.PersistentFlags().BoolP("no-back-referencing", "", false, "Disable back referencing in the loaded relationship structs")
   103  	rootCmd.PersistentFlags().BoolP("always-wrap-errors", "", false, "Wrap all returned errors with stacktraces, also sql.ErrNoRows")
   104  	rootCmd.PersistentFlags().BoolP("add-global-variants", "", false, "Enable generation for global variants")
   105  	rootCmd.PersistentFlags().BoolP("add-panic-variants", "", false, "Enable generation for panic variants")
   106  	rootCmd.PersistentFlags().BoolP("add-soft-deletes", "", false, "Enable soft deletion by updating deleted_at timestamp")
   107  	rootCmd.PersistentFlags().BoolP("add-enum-types", "", false, "Enable generation of types for enums")
   108  	rootCmd.PersistentFlags().StringP("enum-null-prefix", "", "Null", "Name prefix of nullable enum types")
   109  	rootCmd.PersistentFlags().BoolP("version", "", false, "Print the version")
   110  	rootCmd.PersistentFlags().BoolP("wipe", "", false, "Delete the output folder (rm -rf) before generation to ensure sanity")
   111  	rootCmd.PersistentFlags().StringP("struct-tag-casing", "", "snake", "Decides the casing for go structure tag names. camel, title or snake (default snake)")
   112  	rootCmd.PersistentFlags().StringP("relation-tag", "r", "-", "Relationship struct tag name")
   113  	rootCmd.PersistentFlags().StringSliceP("tag-ignore", "", nil, "List of column names that should have tags values set to '-' (ignored during parsing)")
   114  
   115  	// hide flags not recommended for use
   116  	rootCmd.PersistentFlags().MarkHidden("replace")
   117  
   118  	viper.BindPFlags(rootCmd.PersistentFlags())
   119  	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
   120  	viper.AutomaticEnv()
   121  
   122  	if err := rootCmd.Execute(); err != nil {
   123  		if e, ok := err.(commandFailure); ok {
   124  			fmt.Printf("Error: %v\n\n", string(e))
   125  			rootCmd.Help()
   126  		} else if !viper.GetBool("debug") {
   127  			fmt.Printf("Error: %v\n", err)
   128  		} else {
   129  			fmt.Printf("Error: %+v\n", err)
   130  		}
   131  
   132  		os.Exit(1)
   133  	}
   134  }
   135  
   136  type commandFailure string
   137  
   138  func (c commandFailure) Error() string {
   139  	return string(c)
   140  }
   141  
   142  func preRun(cmd *cobra.Command, args []string) error {
   143  	var err error
   144  
   145  	if len(args) == 0 {
   146  		return commandFailure("must provide a driver name")
   147  	}
   148  
   149  	driverName, driverPath, err := drivers.RegisterBinaryFromCmdArg(args[0])
   150  	if err != nil {
   151  		return errors.Wrap(err, "could not register driver")
   152  	}
   153  
   154  	cmdConfig = &boilingcore.Config{
   155  		DriverName:        driverName,
   156  		OutFolder:         viper.GetString("output"),
   157  		PkgName:           viper.GetString("pkgname"),
   158  		Debug:             viper.GetBool("debug"),
   159  		AddGlobal:         viper.GetBool("add-global-variants"),
   160  		AddPanic:          viper.GetBool("add-panic-variants"),
   161  		AddSoftDeletes:    viper.GetBool("add-soft-deletes"),
   162  		AddEnumTypes:      viper.GetBool("add-enum-types"),
   163  		EnumNullPrefix:    viper.GetString("enum-null-prefix"),
   164  		NoContext:         viper.GetBool("no-context"),
   165  		NoTests:           viper.GetBool("no-tests"),
   166  		NoHooks:           viper.GetBool("no-hooks"),
   167  		NoRowsAffected:    viper.GetBool("no-rows-affected"),
   168  		NoAutoTimestamps:  viper.GetBool("no-auto-timestamps"),
   169  		NoDriverTemplates: viper.GetBool("no-driver-templates"),
   170  		NoBackReferencing: viper.GetBool("no-back-referencing"),
   171  		AlwaysWrapErrors:  viper.GetBool("always-wrap-errors"),
   172  		Wipe:              viper.GetBool("wipe"),
   173  		StructTagCasing:   strings.ToLower(viper.GetString("struct-tag-casing")), // camel | snake | title
   174  		TagIgnore:         viper.GetStringSlice("tag-ignore"),
   175  		RelationTag:       viper.GetString("relation-tag"),
   176  		TemplateDirs:      viper.GetStringSlice("templates"),
   177  		Tags:              viper.GetStringSlice("tag"),
   178  		Replacements:      viper.GetStringSlice("replace"),
   179  		Aliases:           boilingcore.ConvertAliases(viper.Get("aliases")),
   180  		TypeReplaces:      boilingcore.ConvertTypeReplace(viper.Get("types")),
   181  		AutoColumns: boilingcore.AutoColumns{
   182  			Created: viper.GetString("auto-columns.created"),
   183  			Updated: viper.GetString("auto-columns.updated"),
   184  			Deleted: viper.GetString("auto-columns.deleted"),
   185  		},
   186  		Inflections: boilingcore.Inflections{
   187  			Plural:        viper.GetStringMapString("inflections.plural"),
   188  			PluralExact:   viper.GetStringMapString("inflections.plural_exact"),
   189  			Singular:      viper.GetStringMapString("inflections.singular"),
   190  			SingularExact: viper.GetStringMapString("inflections.singular_exact"),
   191  			Irregular:     viper.GetStringMapString("inflections.irregular"),
   192  		},
   193  
   194  		Version: sqlBoilerVersion,
   195  	}
   196  
   197  	if cmdConfig.Debug {
   198  		fmt.Fprintln(os.Stderr, "using driver:", driverPath)
   199  	}
   200  
   201  	// Configure the driver
   202  	cmdConfig.DriverConfig = map[string]interface{}{
   203  		"whitelist":        viper.GetStringSlice(driverName + ".whitelist"),
   204  		"blacklist":        viper.GetStringSlice(driverName + ".blacklist"),
   205  		"add-enum-types":   cmdConfig.AddEnumTypes,
   206  		"enum-null-prefix": cmdConfig.EnumNullPrefix,
   207  	}
   208  
   209  	keys := allKeys(driverName)
   210  	for _, key := range keys {
   211  		if key != "blacklist" && key != "whitelist" {
   212  			prefixedKey := fmt.Sprintf("%s.%s", driverName, key)
   213  			cmdConfig.DriverConfig[key] = viper.Get(prefixedKey)
   214  		}
   215  	}
   216  
   217  	cmdConfig.Imports = configureImports()
   218  
   219  	cmdState, err = boilingcore.New(cmdConfig)
   220  	return err
   221  }
   222  
   223  func configureImports() importers.Collection {
   224  	imports := importers.NewDefaultImports()
   225  
   226  	mustMap := func(m importers.Map, err error) importers.Map {
   227  		if err != nil {
   228  			panic("failed to change viper interface into importers.Map: " + err.Error())
   229  		}
   230  
   231  		return m
   232  	}
   233  
   234  	if viper.IsSet("imports.all.standard") {
   235  		imports.All.Standard = viper.GetStringSlice("imports.all.standard")
   236  	}
   237  	if viper.IsSet("imports.all.third_party") {
   238  		imports.All.ThirdParty = viper.GetStringSlice("imports.all.third_party")
   239  	}
   240  	if viper.IsSet("imports.test.standard") {
   241  		imports.Test.Standard = viper.GetStringSlice("imports.test.standard")
   242  	}
   243  	if viper.IsSet("imports.test.third_party") {
   244  		imports.Test.ThirdParty = viper.GetStringSlice("imports.test.third_party")
   245  	}
   246  	if viper.IsSet("imports.singleton") {
   247  		imports.Singleton = mustMap(importers.MapFromInterface(viper.Get("imports.singleton")))
   248  	}
   249  	if viper.IsSet("imports.test_singleton") {
   250  		imports.TestSingleton = mustMap(importers.MapFromInterface(viper.Get("imports.test_singleton")))
   251  	}
   252  	if viper.IsSet("imports.based_on_type") {
   253  		imports.BasedOnType = mustMap(importers.MapFromInterface(viper.Get("imports.based_on_type")))
   254  	}
   255  
   256  	return imports
   257  }
   258  
   259  func run(cmd *cobra.Command, args []string) error {
   260  	return cmdState.Run()
   261  }
   262  
   263  func postRun(cmd *cobra.Command, args []string) error {
   264  	return cmdState.Cleanup()
   265  }
   266  
   267  func allKeys(prefix string) []string {
   268  	keys := make(map[string]bool)
   269  
   270  	prefix += "."
   271  
   272  	for _, e := range os.Environ() {
   273  		splits := strings.SplitN(e, "=", 2)
   274  		key := strings.ReplaceAll(strings.ToLower(splits[0]), "_", ".")
   275  
   276  		if strings.HasPrefix(key, prefix) {
   277  			keys[strings.ReplaceAll(key, prefix, "")] = true
   278  		}
   279  	}
   280  
   281  	for _, key := range viper.AllKeys() {
   282  		if strings.HasPrefix(key, prefix) {
   283  			keys[strings.ReplaceAll(key, prefix, "")] = true
   284  		}
   285  	}
   286  
   287  	keySlice := make([]string, 0, len(keys))
   288  	for k := range keys {
   289  		keySlice = append(keySlice, k)
   290  	}
   291  	return keySlice
   292  }