github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/main.go (about) 1 // Package main defines a command line interface for the sqlboiler package 2 package main 3 4 import ( 5 "fmt" 6 "os" 7 "path/filepath" 8 "strings" 9 10 "github.com/friendsofgo/errors" 11 "github.com/spf13/cobra" 12 "github.com/spf13/viper" 13 "github.com/volatiletech/sqlboiler/v4/boilingcore" 14 "github.com/volatiletech/sqlboiler/v4/drivers" 15 "github.com/volatiletech/sqlboiler/v4/importers" 16 ) 17 18 const sqlBoilerVersion = "IOTech" //"4.13.0" 19 20 var ( 21 flagConfigFile string 22 cmdState *boilingcore.State 23 cmdConfig *boilingcore.Config 24 ) 25 26 func initConfig() { 27 if len(flagConfigFile) != 0 { 28 viper.SetConfigFile(flagConfigFile) 29 if err := viper.ReadInConfig(); err != nil { 30 fmt.Println("Can't read config:", err) 31 os.Exit(1) 32 } 33 return 34 } 35 36 var err error 37 viper.SetConfigName("sqlboiler") 38 39 configHome := os.Getenv("XDG_CONFIG_HOME") 40 homePath := os.Getenv("HOME") 41 wd, err := os.Getwd() 42 if err != nil { 43 wd = "." 44 } 45 46 configPaths := []string{wd} 47 if len(configHome) > 0 { 48 configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler")) 49 } else { 50 configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler")) 51 } 52 53 for _, p := range configPaths { 54 viper.AddConfigPath(p) 55 } 56 57 // Ignore errors here, fallback to other validation methods. 58 // Users can use environment variables if a config is not found. 59 _ = viper.ReadInConfig() 60 } 61 62 func main() { 63 // Too much happens between here and cobra's argument handling, for 64 // something so simple just do it immediately. 65 for _, arg := range os.Args { 66 if arg == "--version" { 67 fmt.Println("SQLBoiler v" + sqlBoilerVersion) 68 return 69 } 70 } 71 72 // Set up the cobra root command 73 rootCmd := &cobra.Command{ 74 Use: "sqlboiler [flags] <driver>", 75 Short: "SQL Boiler generates an ORM tailored to your database schema.", 76 Long: "SQL Boiler generates a Go ORM from template files, tailored to your database schema.\n" + 77 `Complete documentation is available at http://github.com/volatiletech/sqlboiler`, 78 Example: `sqlboiler psql`, 79 PreRunE: preRun, 80 RunE: run, 81 PostRunE: postRun, 82 SilenceErrors: true, 83 SilenceUsage: true, 84 } 85 86 cobra.OnInitialize(initConfig) 87 88 // Set up the cobra root command flags 89 rootCmd.PersistentFlags().StringVarP(&flagConfigFile, "config", "c", "", "Filename of config file to override default lookup") 90 rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to") 91 rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package") 92 rootCmd.PersistentFlags().StringSliceP("templates", "", nil, "A templates directory, overrides the embedded template folders in sqlboiler") 93 rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml") 94 rootCmd.PersistentFlags().StringSliceP("replace", "", nil, "Replace templates by directory: relpath/to_file.tpl:relpath/to_replacement.tpl") 95 rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error") 96 rootCmd.PersistentFlags().BoolP("no-context", "", false, "Disable context.Context usage in the generated code") 97 rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files") 98 rootCmd.PersistentFlags().BoolP("no-hooks", "", false, "Disable hooks feature for your models") 99 rootCmd.PersistentFlags().BoolP("no-rows-affected", "", false, "Disable rows affected in the generated API") 100 rootCmd.PersistentFlags().BoolP("no-auto-timestamps", "", false, "Disable automatic timestamps for created_at/updated_at") 101 rootCmd.PersistentFlags().BoolP("no-driver-templates", "", false, "Disable parsing of templates defined by the database driver") 102 rootCmd.PersistentFlags().BoolP("no-back-referencing", "", false, "Disable back referencing in the loaded relationship structs") 103 rootCmd.PersistentFlags().BoolP("always-wrap-errors", "", false, "Wrap all returned errors with stacktraces, also sql.ErrNoRows") 104 rootCmd.PersistentFlags().BoolP("add-global-variants", "", false, "Enable generation for global variants") 105 rootCmd.PersistentFlags().BoolP("add-panic-variants", "", false, "Enable generation for panic variants") 106 rootCmd.PersistentFlags().BoolP("add-soft-deletes", "", false, "Enable soft deletion by updating deleted_at timestamp") 107 rootCmd.PersistentFlags().BoolP("add-enum-types", "", false, "Enable generation of types for enums") 108 rootCmd.PersistentFlags().StringP("enum-null-prefix", "", "Null", "Name prefix of nullable enum types") 109 rootCmd.PersistentFlags().BoolP("version", "", false, "Print the version") 110 rootCmd.PersistentFlags().BoolP("wipe", "", false, "Delete the output folder (rm -rf) before generation to ensure sanity") 111 rootCmd.PersistentFlags().StringP("struct-tag-casing", "", "snake", "Decides the casing for go structure tag names. camel, title or snake (default snake)") 112 rootCmd.PersistentFlags().StringP("relation-tag", "r", "-", "Relationship struct tag name") 113 rootCmd.PersistentFlags().StringSliceP("tag-ignore", "", nil, "List of column names that should have tags values set to '-' (ignored during parsing)") 114 115 // hide flags not recommended for use 116 rootCmd.PersistentFlags().MarkHidden("replace") 117 118 viper.BindPFlags(rootCmd.PersistentFlags()) 119 viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) 120 viper.AutomaticEnv() 121 122 if err := rootCmd.Execute(); err != nil { 123 if e, ok := err.(commandFailure); ok { 124 fmt.Printf("Error: %v\n\n", string(e)) 125 rootCmd.Help() 126 } else if !viper.GetBool("debug") { 127 fmt.Printf("Error: %v\n", err) 128 } else { 129 fmt.Printf("Error: %+v\n", err) 130 } 131 132 os.Exit(1) 133 } 134 } 135 136 type commandFailure string 137 138 func (c commandFailure) Error() string { 139 return string(c) 140 } 141 142 func preRun(cmd *cobra.Command, args []string) error { 143 var err error 144 145 if len(args) == 0 { 146 return commandFailure("must provide a driver name") 147 } 148 149 driverName, driverPath, err := drivers.RegisterBinaryFromCmdArg(args[0]) 150 if err != nil { 151 return errors.Wrap(err, "could not register driver") 152 } 153 154 cmdConfig = &boilingcore.Config{ 155 DriverName: driverName, 156 OutFolder: viper.GetString("output"), 157 PkgName: viper.GetString("pkgname"), 158 Debug: viper.GetBool("debug"), 159 AddGlobal: viper.GetBool("add-global-variants"), 160 AddPanic: viper.GetBool("add-panic-variants"), 161 AddSoftDeletes: viper.GetBool("add-soft-deletes"), 162 AddEnumTypes: viper.GetBool("add-enum-types"), 163 EnumNullPrefix: viper.GetString("enum-null-prefix"), 164 NoContext: viper.GetBool("no-context"), 165 NoTests: viper.GetBool("no-tests"), 166 NoHooks: viper.GetBool("no-hooks"), 167 NoRowsAffected: viper.GetBool("no-rows-affected"), 168 NoAutoTimestamps: viper.GetBool("no-auto-timestamps"), 169 NoDriverTemplates: viper.GetBool("no-driver-templates"), 170 NoBackReferencing: viper.GetBool("no-back-referencing"), 171 AlwaysWrapErrors: viper.GetBool("always-wrap-errors"), 172 Wipe: viper.GetBool("wipe"), 173 StructTagCasing: strings.ToLower(viper.GetString("struct-tag-casing")), // camel | snake | title 174 TagIgnore: viper.GetStringSlice("tag-ignore"), 175 RelationTag: viper.GetString("relation-tag"), 176 TemplateDirs: viper.GetStringSlice("templates"), 177 Tags: viper.GetStringSlice("tag"), 178 Replacements: viper.GetStringSlice("replace"), 179 Aliases: boilingcore.ConvertAliases(viper.Get("aliases")), 180 TypeReplaces: boilingcore.ConvertTypeReplace(viper.Get("types")), 181 AutoColumns: boilingcore.AutoColumns{ 182 Created: viper.GetString("auto-columns.created"), 183 Updated: viper.GetString("auto-columns.updated"), 184 Deleted: viper.GetString("auto-columns.deleted"), 185 }, 186 Inflections: boilingcore.Inflections{ 187 Plural: viper.GetStringMapString("inflections.plural"), 188 PluralExact: viper.GetStringMapString("inflections.plural_exact"), 189 Singular: viper.GetStringMapString("inflections.singular"), 190 SingularExact: viper.GetStringMapString("inflections.singular_exact"), 191 Irregular: viper.GetStringMapString("inflections.irregular"), 192 }, 193 194 Version: sqlBoilerVersion, 195 } 196 197 if cmdConfig.Debug { 198 fmt.Fprintln(os.Stderr, "using driver:", driverPath) 199 } 200 201 // Configure the driver 202 cmdConfig.DriverConfig = map[string]interface{}{ 203 "whitelist": viper.GetStringSlice(driverName + ".whitelist"), 204 "blacklist": viper.GetStringSlice(driverName + ".blacklist"), 205 "add-enum-types": cmdConfig.AddEnumTypes, 206 "enum-null-prefix": cmdConfig.EnumNullPrefix, 207 } 208 209 keys := allKeys(driverName) 210 for _, key := range keys { 211 if key != "blacklist" && key != "whitelist" { 212 prefixedKey := fmt.Sprintf("%s.%s", driverName, key) 213 cmdConfig.DriverConfig[key] = viper.Get(prefixedKey) 214 } 215 } 216 217 cmdConfig.Imports = configureImports() 218 219 cmdState, err = boilingcore.New(cmdConfig) 220 return err 221 } 222 223 func configureImports() importers.Collection { 224 imports := importers.NewDefaultImports() 225 226 mustMap := func(m importers.Map, err error) importers.Map { 227 if err != nil { 228 panic("failed to change viper interface into importers.Map: " + err.Error()) 229 } 230 231 return m 232 } 233 234 if viper.IsSet("imports.all.standard") { 235 imports.All.Standard = viper.GetStringSlice("imports.all.standard") 236 } 237 if viper.IsSet("imports.all.third_party") { 238 imports.All.ThirdParty = viper.GetStringSlice("imports.all.third_party") 239 } 240 if viper.IsSet("imports.test.standard") { 241 imports.Test.Standard = viper.GetStringSlice("imports.test.standard") 242 } 243 if viper.IsSet("imports.test.third_party") { 244 imports.Test.ThirdParty = viper.GetStringSlice("imports.test.third_party") 245 } 246 if viper.IsSet("imports.singleton") { 247 imports.Singleton = mustMap(importers.MapFromInterface(viper.Get("imports.singleton"))) 248 } 249 if viper.IsSet("imports.test_singleton") { 250 imports.TestSingleton = mustMap(importers.MapFromInterface(viper.Get("imports.test_singleton"))) 251 } 252 if viper.IsSet("imports.based_on_type") { 253 imports.BasedOnType = mustMap(importers.MapFromInterface(viper.Get("imports.based_on_type"))) 254 } 255 256 return imports 257 } 258 259 func run(cmd *cobra.Command, args []string) error { 260 return cmdState.Run() 261 } 262 263 func postRun(cmd *cobra.Command, args []string) error { 264 return cmdState.Cleanup() 265 } 266 267 func allKeys(prefix string) []string { 268 keys := make(map[string]bool) 269 270 prefix += "." 271 272 for _, e := range os.Environ() { 273 splits := strings.SplitN(e, "=", 2) 274 key := strings.ReplaceAll(strings.ToLower(splits[0]), "_", ".") 275 276 if strings.HasPrefix(key, prefix) { 277 keys[strings.ReplaceAll(key, prefix, "")] = true 278 } 279 } 280 281 for _, key := range viper.AllKeys() { 282 if strings.HasPrefix(key, prefix) { 283 keys[strings.ReplaceAll(key, prefix, "")] = true 284 } 285 } 286 287 keySlice := make([]string, 0, len(keys)) 288 for k := range keys { 289 keySlice = append(keySlice, k) 290 } 291 return keySlice 292 }