go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/cli/providers/providers.go (about) 1 // Copyright (c) Mondoo, Inc. 2 // SPDX-License-Identifier: BUSL-1.1 3 4 package providers 5 6 import ( 7 "encoding/json" 8 "os" 9 "strings" 10 11 "github.com/rs/zerolog/log" 12 "github.com/spf13/cobra" 13 "github.com/spf13/pflag" 14 "github.com/spf13/viper" 15 "go.mondoo.com/cnquery/cli/components" 16 "go.mondoo.com/cnquery/cli/config" 17 "go.mondoo.com/cnquery/llx" 18 "go.mondoo.com/cnquery/providers" 19 "go.mondoo.com/cnquery/providers-sdk/v1/plugin" 20 "go.mondoo.com/cnquery/types" 21 ) 22 23 type Command struct { 24 Command *cobra.Command 25 Run func(*cobra.Command, *providers.Runtime, *plugin.ParseCLIRes) 26 Action string 27 } 28 29 // AttachCLIs will attempt to parse the current commandline and look for providers. 30 // This step is done before cobra ever takes effect 31 func AttachCLIs(rootCmd *cobra.Command, commands ...*Command) error { 32 existing, err := providers.ListActive() 33 if err != nil { 34 return err 35 } 36 37 connectorName, autoUpdate := detectConnectorName(os.Args, rootCmd, commands, existing) 38 if connectorName != "" { 39 if _, err := providers.EnsureProvider(connectorName, "", autoUpdate, existing); err != nil { 40 return err 41 } 42 } 43 44 // Now that we know we have all providers, it's time to load them to build 45 // the remaining CLI. Probably an opportunity to optimize in the future, 46 // but fine for now to do it all. 47 48 attachProviders(existing, commands) 49 return nil 50 } 51 52 func detectConnectorName(args []string, rootCmd *cobra.Command, commands []*Command, providers providers.Providers) (string, bool) { 53 autoUpdate := true 54 55 config.InitViperConfig() 56 if viper.IsSet("auto_update") { 57 autoUpdate = viper.GetBool("auto_update") 58 } 59 60 flags := pflag.NewFlagSet("set", pflag.ContinueOnError) 61 flags.ParseErrorsWhitelist.UnknownFlags = true 62 flags.Bool("auto-update", autoUpdate, "") 63 flags.BoolP("help", "h", false, "") 64 65 builtins := genBuiltinFlags() 66 for i := range builtins { 67 attachFlag(flags, builtins[i]) 68 } 69 70 // To avoid warnings about flags, we need to mock all flags on the root command. 71 // The command after root (eg: run, scan, shell, ...) are normal actions that 72 // we want to detect. We only need to add flags from the root command and the 73 // attaching subcommands (since those are the ones which end up giving us 74 // the connector) 75 attachPFlags(flags, rootCmd.Flags()) 76 attachPFlags(flags, rootCmd.PersistentFlags()) 77 78 for i := range commands { 79 cmd := commands[i] 80 attachPFlags(flags, cmd.Command.Flags()) 81 } 82 83 for i := range providers { 84 provider := providers[i] 85 for j := range provider.Connectors { 86 conn := provider.Connectors[j] 87 for k := range conn.Flags { 88 flag := conn.Flags[k] 89 if found := flags.Lookup(flag.Long); found == nil { 90 attachFlag(flags, flag) 91 } 92 } 93 } 94 } 95 96 err := flags.Parse(args) 97 if err != nil { 98 log.Warn().Err(err).Msg("CLI pre-processing encountered an issue") 99 } 100 101 autoUpdate, _ = flags.GetBool("auto-update") 102 103 parsedArgs := flags.Args() 104 if len(parsedArgs) <= 1 { 105 return "", autoUpdate 106 } 107 108 commandFound := false 109 for j := range commands { 110 if commands[j].Command.Use == parsedArgs[1] { 111 commandFound = true 112 break 113 } 114 } 115 if !commandFound { 116 return "", autoUpdate 117 } 118 119 // since we have a known command, we can now expect the connector to be 120 // local by default if nothing else is set 121 if len(parsedArgs) == 2 { 122 return "local", autoUpdate 123 } 124 125 connector := parsedArgs[2] 126 // we may want to double-check if the connector exists 127 128 return connector, autoUpdate 129 } 130 131 func attachProviders(existing providers.Providers, commands []*Command) { 132 for i := range commands { 133 attachProvidersToCmd(existing, commands[i]) 134 } 135 } 136 137 func attachProvidersToCmd(existing providers.Providers, cmd *Command) { 138 for _, provider := range existing { 139 for j := range provider.Connectors { 140 conn := provider.Connectors[j] 141 attachConnectorCmd(provider.Provider, &conn, cmd) 142 for k := range conn.Aliases { 143 copyConn := conn 144 copyConn.Name = conn.Aliases[k] 145 attachConnectorCmd(provider.Provider, ©Conn, cmd) 146 } 147 } 148 } 149 150 // the default is always os.local if it exists 151 if p, ok := existing[providers.DefaultOsID]; ok { 152 for i := range p.Connectors { 153 c := p.Connectors[i] 154 if c.Name == "local" { 155 setDefaultConnector(p.Provider, &c, cmd) 156 break 157 } 158 } 159 } 160 } 161 162 func setDefaultConnector(provider *plugin.Provider, connector *plugin.Connector, cmd *Command) { 163 cmd.Command.Run = func(cmd *cobra.Command, args []string) { 164 if len(args) > 0 { 165 log.Error().Msg("provider " + args[0] + " does not exist") 166 cmd.Help() 167 os.Exit(1) 168 } 169 170 log.Info().Msg("no provider specified, defaulting to local. Use --help to see all providers.") 171 } 172 cmd.Command.Short = cmd.Action + connector.Short 173 174 setConnector(provider, connector, cmd.Run, cmd.Command) 175 } 176 177 func attachConnectorCmd(provider *plugin.Provider, connector *plugin.Connector, cmd *Command) { 178 res := &cobra.Command{ 179 Use: connector.Use, 180 Short: cmd.Action + connector.Short, 181 Long: connector.Long, 182 Aliases: connector.Aliases, 183 PreRun: cmd.Command.PreRun, 184 } 185 186 if connector.MinArgs == connector.MaxArgs { 187 if connector.MinArgs == 0 { 188 res.Args = cobra.NoArgs 189 } else { 190 res.Args = cobra.ExactArgs(int(connector.MinArgs)) 191 } 192 } else { 193 if connector.MaxArgs > 0 && connector.MinArgs == 0 { 194 res.Args = cobra.MaximumNArgs(int(connector.MaxArgs)) 195 } else if connector.MaxArgs == 0 && connector.MinArgs > 0 { 196 res.Args = cobra.MinimumNArgs(int(connector.MinArgs)) 197 } else { 198 res.Args = cobra.RangeArgs(int(connector.MinArgs), int(connector.MaxArgs)) 199 } 200 } 201 cmd.Command.Flags().VisitAll(func(flag *pflag.Flag) { 202 res.Flags().AddFlag(flag) 203 }) 204 205 cmd.Command.AddCommand(res) 206 setConnector(provider, connector, cmd.Run, res) 207 } 208 209 func genBuiltinFlags(discoveries ...string) []plugin.Flag { 210 supportedDiscoveries := append([]string{"all", "auto"}, discoveries...) 211 212 return []plugin.Flag{ 213 // flags for providers: 214 { 215 Long: "discover", 216 Type: plugin.FlagType_List, 217 Desc: "Enable the discovery of nested assets. Supports: " + strings.Join(supportedDiscoveries, ","), 218 }, 219 { 220 Long: "pretty", 221 Type: plugin.FlagType_Bool, 222 Desc: "Pretty-print JSON", 223 Option: plugin.FlagOption_Hidden, 224 }, 225 // runtime-only flags: 226 { 227 Long: "record", 228 Type: plugin.FlagType_String, 229 Desc: "Record all resource calls and use resources in the recording", 230 }, 231 { 232 Long: "use-recording", 233 Type: plugin.FlagType_String, 234 Desc: "Use a recording to inject resource data (read-only)", 235 }, 236 } 237 } 238 239 // the following flags are not processed by providers 240 var skipFlags = map[string]struct{}{ 241 "ask-pass": {}, 242 "record": {}, 243 "use-recording": {}, 244 } 245 246 func attachPFlags(base *pflag.FlagSet, nu *pflag.FlagSet) { 247 nu.VisitAll(func(flag *pflag.Flag) { 248 if found := base.Lookup(flag.Name); found != nil { 249 return 250 } 251 if flag.Shorthand != "" { 252 if found := base.ShorthandLookup(flag.Shorthand); found != nil { 253 return 254 } 255 } 256 base.AddFlag(flag) 257 }) 258 } 259 260 func attachFlag(flagset *pflag.FlagSet, flag plugin.Flag) { 261 switch flag.Type { 262 case plugin.FlagType_Bool: 263 if flag.Short != "" { 264 flagset.BoolP(flag.Long, flag.Short, json2T(flag.Default, false), flag.Desc) 265 } else { 266 flagset.Bool(flag.Long, json2T(flag.Default, false), flag.Desc) 267 } 268 case plugin.FlagType_Int: 269 if flag.Short != "" { 270 flagset.IntP(flag.Long, flag.Short, json2T(flag.Default, 0), flag.Desc) 271 } else { 272 flagset.Int(flag.Long, json2T(flag.Default, 0), flag.Desc) 273 } 274 case plugin.FlagType_String: 275 if flag.Short != "" { 276 flagset.StringP(flag.Long, flag.Short, flag.Default, flag.Desc) 277 } else { 278 flagset.String(flag.Long, flag.Default, flag.Desc) 279 } 280 case plugin.FlagType_List: 281 if flag.Short != "" { 282 flagset.StringSliceP(flag.Long, flag.Short, json2T(flag.Default, []string{}), flag.Desc) 283 } else { 284 flagset.StringSlice(flag.Long, json2T(flag.Default, []string{}), flag.Desc) 285 } 286 case plugin.FlagType_KeyValue: 287 if flag.Short != "" { 288 flagset.StringToStringP(flag.Long, flag.Short, json2T(flag.Default, map[string]string{}), flag.Desc) 289 } else { 290 flagset.StringToString(flag.Long, json2T(flag.Default, map[string]string{}), flag.Desc) 291 } 292 } 293 294 if flag.Option&plugin.FlagOption_Hidden != 0 { 295 flagset.MarkHidden(flag.Long) 296 } 297 if flag.Option&plugin.FlagOption_Deprecated != 0 { 298 flagset.MarkDeprecated(flag.Long, "has been deprecated") 299 } 300 } 301 302 func attachFlags(flagset *pflag.FlagSet, flags []plugin.Flag) { 303 for i := range flags { 304 attachFlag(flagset, flags[i]) 305 } 306 } 307 308 func getFlagValue(flag plugin.Flag, cmd *cobra.Command) *llx.Primitive { 309 switch flag.Type { 310 case plugin.FlagType_Bool: 311 v, err := cmd.Flags().GetBool(flag.Long) 312 if err == nil { 313 return llx.BoolPrimitive(v) 314 } 315 log.Warn().Err(err).Msg("failed to get flag " + flag.Long) 316 case plugin.FlagType_Int: 317 if v, err := cmd.Flags().GetInt(flag.Long); err == nil { 318 return llx.IntPrimitive(int64(v)) 319 } 320 case plugin.FlagType_String: 321 if v, err := cmd.Flags().GetString(flag.Long); err == nil { 322 return llx.StringPrimitive(v) 323 } 324 case plugin.FlagType_List: 325 if v, err := cmd.Flags().GetStringSlice(flag.Long); err == nil { 326 return llx.ArrayPrimitiveT(v, llx.StringPrimitive, types.String) 327 } 328 case plugin.FlagType_KeyValue: 329 if v, err := cmd.Flags().GetStringToString(flag.Long); err == nil { 330 return llx.MapPrimitiveT(v, llx.StringPrimitive, types.String) 331 } 332 default: 333 log.Warn().Msg("unknown flag type for " + flag.Long) 334 return nil 335 } 336 return nil 337 } 338 339 func setConnector(provider *plugin.Provider, connector *plugin.Connector, run func(*cobra.Command, *providers.Runtime, *plugin.ParseCLIRes), cmd *cobra.Command) { 340 oldRun := cmd.Run 341 oldPreRun := cmd.PreRun 342 343 builtinFlags := genBuiltinFlags(connector.Discovery...) 344 allFlags := append(connector.Flags, builtinFlags...) 345 346 cmd.PreRun = func(cc *cobra.Command, args []string) { 347 if oldPreRun != nil { 348 oldPreRun(cc, args) 349 } 350 351 // Config options need to be connected to flags before the Run begins. 352 // Flags are provided by the connector. 353 for i := range allFlags { 354 flag := allFlags[i] 355 if flag.ConfigEntry == "-" { 356 continue 357 } 358 359 flagName := flag.ConfigEntry 360 if flagName == "" { 361 flagName = flag.Long 362 } 363 364 viper.BindPFlag(flagName, cmd.Flags().Lookup(flag.Long)) 365 } 366 } 367 368 cmd.Run = func(cc *cobra.Command, args []string) { 369 if oldRun != nil { 370 oldRun(cc, args) 371 } 372 373 log.Debug().Msg("using provider " + provider.Name + " with connector " + connector.Name) 374 375 // TODO: replace this hard-coded block. This should be dynamic for all 376 // fields that are specified to be passwords with the --ask-field 377 // associated with it to make it simple. 378 // check if the user used --password without a value 379 askPass, err := cc.Flags().GetBool("ask-pass") 380 if err == nil && askPass { 381 pass, err := components.AskPassword("Enter password: ") 382 if err != nil { 383 log.Fatal().Err(err).Msg("failed to get password") 384 } 385 cc.Flags().Set("password", pass) 386 } 387 // ^^ 388 389 useRecording, err := cc.Flags().GetString("use-recording") 390 if err != nil { 391 log.Warn().Msg("failed to get flag --recording") 392 } 393 record, err := cc.Flags().GetString("record") 394 if err != nil { 395 log.Warn().Msg("failed to get flag --record") 396 } 397 pretty, err := cc.Flags().GetBool("pretty") 398 if err != nil { 399 log.Warn().Msg("failed to get flag --pretty") 400 } 401 402 flagVals := map[string]*llx.Primitive{} 403 for i := range allFlags { 404 flag := allFlags[i] 405 406 // we skip these because they are coded above 407 if _, skip := skipFlags[flag.Long]; skip { 408 continue 409 } 410 411 if v := getFlagValue(flag, cmd); v != nil { 412 flagVals[flag.Long] = v 413 } 414 } 415 416 // TODO: add flag to set timeout and then use RuntimeWithShutdownTimeout 417 runtime := providers.Coordinator.NewRuntime() 418 if err = providers.SetDefaultRuntime(runtime); err != nil { 419 log.Error().Msg(err.Error()) 420 } 421 422 autoUpdate := true 423 if viper.IsSet("auto_update") { 424 autoUpdate = viper.GetBool("auto_update") 425 } 426 427 runtime.AutoUpdate = providers.UpdateProvidersConfig{ 428 Enabled: autoUpdate, 429 RefreshInterval: 60 * 60, 430 } 431 432 if err := runtime.UseProvider(provider.ID); err != nil { 433 providers.Coordinator.Shutdown() 434 log.Fatal().Err(err).Msg("failed to start provider " + provider.Name) 435 } 436 437 if record != "" && useRecording != "" { 438 log.Fatal().Msg("please only use --record or --use-recording, but not both at the same time") 439 } 440 recordingPath := record 441 if recordingPath == "" { 442 recordingPath = useRecording 443 } 444 doRecord := record != "" 445 446 recording, err := providers.NewRecording(recordingPath, providers.RecordingOptions{ 447 DoRecord: doRecord, 448 PrettyPrintJSON: pretty, 449 }) 450 if err != nil { 451 log.Fatal().Msg(err.Error()) 452 } 453 runtime.SetRecording(recording) 454 455 cliRes, err := runtime.Provider.Instance.Plugin.ParseCLI(&plugin.ParseCLIReq{ 456 Connector: connector.Name, 457 Args: args, 458 Flags: flagVals, 459 }) 460 if err != nil { 461 runtime.Close() 462 providers.Coordinator.Shutdown() 463 log.Fatal().Err(err).Msg("failed to parse cli arguments") 464 } 465 466 if cliRes == nil { 467 runtime.Close() 468 providers.Coordinator.Shutdown() 469 log.Fatal().Msg("failed to process CLI arguments, nothing was returned") 470 } 471 472 run(cc, runtime, cliRes) 473 runtime.Close() 474 providers.Coordinator.Shutdown() 475 } 476 477 attachFlags(cmd.Flags(), allFlags) 478 } 479 480 func json2T[T any](s string, empty T) T { 481 var res T 482 if err := json.Unmarshal([]byte(s), &res); err == nil { 483 return res 484 } 485 return empty 486 }