go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/cli/providers/providers.go (about)

     1  // Copyright (c) Mondoo, Inc.
     2  // SPDX-License-Identifier: BUSL-1.1
     3  
     4  package providers
     5  
     6  import (
     7  	"encoding/json"
     8  	"os"
     9  	"strings"
    10  
    11  	"github.com/rs/zerolog/log"
    12  	"github.com/spf13/cobra"
    13  	"github.com/spf13/pflag"
    14  	"github.com/spf13/viper"
    15  	"go.mondoo.com/cnquery/cli/components"
    16  	"go.mondoo.com/cnquery/cli/config"
    17  	"go.mondoo.com/cnquery/llx"
    18  	"go.mondoo.com/cnquery/providers"
    19  	"go.mondoo.com/cnquery/providers-sdk/v1/plugin"
    20  	"go.mondoo.com/cnquery/types"
    21  )
    22  
    23  type Command struct {
    24  	Command *cobra.Command
    25  	Run     func(*cobra.Command, *providers.Runtime, *plugin.ParseCLIRes)
    26  	Action  string
    27  }
    28  
    29  // AttachCLIs will attempt to parse the current commandline and look for providers.
    30  // This step is done before cobra ever takes effect
    31  func AttachCLIs(rootCmd *cobra.Command, commands ...*Command) error {
    32  	existing, err := providers.ListActive()
    33  	if err != nil {
    34  		return err
    35  	}
    36  
    37  	connectorName, autoUpdate := detectConnectorName(os.Args, rootCmd, commands, existing)
    38  	if connectorName != "" {
    39  		if _, err := providers.EnsureProvider(connectorName, "", autoUpdate, existing); err != nil {
    40  			return err
    41  		}
    42  	}
    43  
    44  	// Now that we know we have all providers, it's time to load them to build
    45  	// the remaining CLI. Probably an opportunity to optimize in the future,
    46  	// but fine for now to do it all.
    47  
    48  	attachProviders(existing, commands)
    49  	return nil
    50  }
    51  
    52  func detectConnectorName(args []string, rootCmd *cobra.Command, commands []*Command, providers providers.Providers) (string, bool) {
    53  	autoUpdate := true
    54  
    55  	config.InitViperConfig()
    56  	if viper.IsSet("auto_update") {
    57  		autoUpdate = viper.GetBool("auto_update")
    58  	}
    59  
    60  	flags := pflag.NewFlagSet("set", pflag.ContinueOnError)
    61  	flags.ParseErrorsWhitelist.UnknownFlags = true
    62  	flags.Bool("auto-update", autoUpdate, "")
    63  	flags.BoolP("help", "h", false, "")
    64  
    65  	builtins := genBuiltinFlags()
    66  	for i := range builtins {
    67  		attachFlag(flags, builtins[i])
    68  	}
    69  
    70  	// To avoid warnings about flags, we need to mock all flags on the root command.
    71  	// The command after root (eg: run, scan, shell, ...) are normal actions that
    72  	// we want to detect. We only need to add flags from the root command and the
    73  	// attaching subcommands (since those are the ones which end up giving us
    74  	// the connector)
    75  	attachPFlags(flags, rootCmd.Flags())
    76  	attachPFlags(flags, rootCmd.PersistentFlags())
    77  
    78  	for i := range commands {
    79  		cmd := commands[i]
    80  		attachPFlags(flags, cmd.Command.Flags())
    81  	}
    82  
    83  	for i := range providers {
    84  		provider := providers[i]
    85  		for j := range provider.Connectors {
    86  			conn := provider.Connectors[j]
    87  			for k := range conn.Flags {
    88  				flag := conn.Flags[k]
    89  				if found := flags.Lookup(flag.Long); found == nil {
    90  					attachFlag(flags, flag)
    91  				}
    92  			}
    93  		}
    94  	}
    95  
    96  	err := flags.Parse(args)
    97  	if err != nil {
    98  		log.Warn().Err(err).Msg("CLI pre-processing encountered an issue")
    99  	}
   100  
   101  	autoUpdate, _ = flags.GetBool("auto-update")
   102  
   103  	parsedArgs := flags.Args()
   104  	if len(parsedArgs) <= 1 {
   105  		return "", autoUpdate
   106  	}
   107  
   108  	commandFound := false
   109  	for j := range commands {
   110  		if commands[j].Command.Use == parsedArgs[1] {
   111  			commandFound = true
   112  			break
   113  		}
   114  	}
   115  	if !commandFound {
   116  		return "", autoUpdate
   117  	}
   118  
   119  	// since we have a known command, we can now expect the connector to be
   120  	// local by default if nothing else is set
   121  	if len(parsedArgs) == 2 {
   122  		return "local", autoUpdate
   123  	}
   124  
   125  	connector := parsedArgs[2]
   126  	// we may want to double-check if the connector exists
   127  
   128  	return connector, autoUpdate
   129  }
   130  
   131  func attachProviders(existing providers.Providers, commands []*Command) {
   132  	for i := range commands {
   133  		attachProvidersToCmd(existing, commands[i])
   134  	}
   135  }
   136  
   137  func attachProvidersToCmd(existing providers.Providers, cmd *Command) {
   138  	for _, provider := range existing {
   139  		for j := range provider.Connectors {
   140  			conn := provider.Connectors[j]
   141  			attachConnectorCmd(provider.Provider, &conn, cmd)
   142  			for k := range conn.Aliases {
   143  				copyConn := conn
   144  				copyConn.Name = conn.Aliases[k]
   145  				attachConnectorCmd(provider.Provider, &copyConn, cmd)
   146  			}
   147  		}
   148  	}
   149  
   150  	// the default is always os.local if it exists
   151  	if p, ok := existing[providers.DefaultOsID]; ok {
   152  		for i := range p.Connectors {
   153  			c := p.Connectors[i]
   154  			if c.Name == "local" {
   155  				setDefaultConnector(p.Provider, &c, cmd)
   156  				break
   157  			}
   158  		}
   159  	}
   160  }
   161  
   162  func setDefaultConnector(provider *plugin.Provider, connector *plugin.Connector, cmd *Command) {
   163  	cmd.Command.Run = func(cmd *cobra.Command, args []string) {
   164  		if len(args) > 0 {
   165  			log.Error().Msg("provider " + args[0] + " does not exist")
   166  			cmd.Help()
   167  			os.Exit(1)
   168  		}
   169  
   170  		log.Info().Msg("no provider specified, defaulting to local. Use --help to see all providers.")
   171  	}
   172  	cmd.Command.Short = cmd.Action + connector.Short
   173  
   174  	setConnector(provider, connector, cmd.Run, cmd.Command)
   175  }
   176  
   177  func attachConnectorCmd(provider *plugin.Provider, connector *plugin.Connector, cmd *Command) {
   178  	res := &cobra.Command{
   179  		Use:     connector.Use,
   180  		Short:   cmd.Action + connector.Short,
   181  		Long:    connector.Long,
   182  		Aliases: connector.Aliases,
   183  		PreRun:  cmd.Command.PreRun,
   184  	}
   185  
   186  	if connector.MinArgs == connector.MaxArgs {
   187  		if connector.MinArgs == 0 {
   188  			res.Args = cobra.NoArgs
   189  		} else {
   190  			res.Args = cobra.ExactArgs(int(connector.MinArgs))
   191  		}
   192  	} else {
   193  		if connector.MaxArgs > 0 && connector.MinArgs == 0 {
   194  			res.Args = cobra.MaximumNArgs(int(connector.MaxArgs))
   195  		} else if connector.MaxArgs == 0 && connector.MinArgs > 0 {
   196  			res.Args = cobra.MinimumNArgs(int(connector.MinArgs))
   197  		} else {
   198  			res.Args = cobra.RangeArgs(int(connector.MinArgs), int(connector.MaxArgs))
   199  		}
   200  	}
   201  	cmd.Command.Flags().VisitAll(func(flag *pflag.Flag) {
   202  		res.Flags().AddFlag(flag)
   203  	})
   204  
   205  	cmd.Command.AddCommand(res)
   206  	setConnector(provider, connector, cmd.Run, res)
   207  }
   208  
   209  func genBuiltinFlags(discoveries ...string) []plugin.Flag {
   210  	supportedDiscoveries := append([]string{"all", "auto"}, discoveries...)
   211  
   212  	return []plugin.Flag{
   213  		// flags for providers:
   214  		{
   215  			Long: "discover",
   216  			Type: plugin.FlagType_List,
   217  			Desc: "Enable the discovery of nested assets. Supports: " + strings.Join(supportedDiscoveries, ","),
   218  		},
   219  		{
   220  			Long:   "pretty",
   221  			Type:   plugin.FlagType_Bool,
   222  			Desc:   "Pretty-print JSON",
   223  			Option: plugin.FlagOption_Hidden,
   224  		},
   225  		// runtime-only flags:
   226  		{
   227  			Long: "record",
   228  			Type: plugin.FlagType_String,
   229  			Desc: "Record all resource calls and use resources in the recording",
   230  		},
   231  		{
   232  			Long: "use-recording",
   233  			Type: plugin.FlagType_String,
   234  			Desc: "Use a recording to inject resource data (read-only)",
   235  		},
   236  	}
   237  }
   238  
   239  // the following flags are not processed by providers
   240  var skipFlags = map[string]struct{}{
   241  	"ask-pass":      {},
   242  	"record":        {},
   243  	"use-recording": {},
   244  }
   245  
   246  func attachPFlags(base *pflag.FlagSet, nu *pflag.FlagSet) {
   247  	nu.VisitAll(func(flag *pflag.Flag) {
   248  		if found := base.Lookup(flag.Name); found != nil {
   249  			return
   250  		}
   251  		if flag.Shorthand != "" {
   252  			if found := base.ShorthandLookup(flag.Shorthand); found != nil {
   253  				return
   254  			}
   255  		}
   256  		base.AddFlag(flag)
   257  	})
   258  }
   259  
   260  func attachFlag(flagset *pflag.FlagSet, flag plugin.Flag) {
   261  	switch flag.Type {
   262  	case plugin.FlagType_Bool:
   263  		if flag.Short != "" {
   264  			flagset.BoolP(flag.Long, flag.Short, json2T(flag.Default, false), flag.Desc)
   265  		} else {
   266  			flagset.Bool(flag.Long, json2T(flag.Default, false), flag.Desc)
   267  		}
   268  	case plugin.FlagType_Int:
   269  		if flag.Short != "" {
   270  			flagset.IntP(flag.Long, flag.Short, json2T(flag.Default, 0), flag.Desc)
   271  		} else {
   272  			flagset.Int(flag.Long, json2T(flag.Default, 0), flag.Desc)
   273  		}
   274  	case plugin.FlagType_String:
   275  		if flag.Short != "" {
   276  			flagset.StringP(flag.Long, flag.Short, flag.Default, flag.Desc)
   277  		} else {
   278  			flagset.String(flag.Long, flag.Default, flag.Desc)
   279  		}
   280  	case plugin.FlagType_List:
   281  		if flag.Short != "" {
   282  			flagset.StringSliceP(flag.Long, flag.Short, json2T(flag.Default, []string{}), flag.Desc)
   283  		} else {
   284  			flagset.StringSlice(flag.Long, json2T(flag.Default, []string{}), flag.Desc)
   285  		}
   286  	case plugin.FlagType_KeyValue:
   287  		if flag.Short != "" {
   288  			flagset.StringToStringP(flag.Long, flag.Short, json2T(flag.Default, map[string]string{}), flag.Desc)
   289  		} else {
   290  			flagset.StringToString(flag.Long, json2T(flag.Default, map[string]string{}), flag.Desc)
   291  		}
   292  	}
   293  
   294  	if flag.Option&plugin.FlagOption_Hidden != 0 {
   295  		flagset.MarkHidden(flag.Long)
   296  	}
   297  	if flag.Option&plugin.FlagOption_Deprecated != 0 {
   298  		flagset.MarkDeprecated(flag.Long, "has been deprecated")
   299  	}
   300  }
   301  
   302  func attachFlags(flagset *pflag.FlagSet, flags []plugin.Flag) {
   303  	for i := range flags {
   304  		attachFlag(flagset, flags[i])
   305  	}
   306  }
   307  
   308  func getFlagValue(flag plugin.Flag, cmd *cobra.Command) *llx.Primitive {
   309  	switch flag.Type {
   310  	case plugin.FlagType_Bool:
   311  		v, err := cmd.Flags().GetBool(flag.Long)
   312  		if err == nil {
   313  			return llx.BoolPrimitive(v)
   314  		}
   315  		log.Warn().Err(err).Msg("failed to get flag " + flag.Long)
   316  	case plugin.FlagType_Int:
   317  		if v, err := cmd.Flags().GetInt(flag.Long); err == nil {
   318  			return llx.IntPrimitive(int64(v))
   319  		}
   320  	case plugin.FlagType_String:
   321  		if v, err := cmd.Flags().GetString(flag.Long); err == nil {
   322  			return llx.StringPrimitive(v)
   323  		}
   324  	case plugin.FlagType_List:
   325  		if v, err := cmd.Flags().GetStringSlice(flag.Long); err == nil {
   326  			return llx.ArrayPrimitiveT(v, llx.StringPrimitive, types.String)
   327  		}
   328  	case plugin.FlagType_KeyValue:
   329  		if v, err := cmd.Flags().GetStringToString(flag.Long); err == nil {
   330  			return llx.MapPrimitiveT(v, llx.StringPrimitive, types.String)
   331  		}
   332  	default:
   333  		log.Warn().Msg("unknown flag type for " + flag.Long)
   334  		return nil
   335  	}
   336  	return nil
   337  }
   338  
   339  func setConnector(provider *plugin.Provider, connector *plugin.Connector, run func(*cobra.Command, *providers.Runtime, *plugin.ParseCLIRes), cmd *cobra.Command) {
   340  	oldRun := cmd.Run
   341  	oldPreRun := cmd.PreRun
   342  
   343  	builtinFlags := genBuiltinFlags(connector.Discovery...)
   344  	allFlags := append(connector.Flags, builtinFlags...)
   345  
   346  	cmd.PreRun = func(cc *cobra.Command, args []string) {
   347  		if oldPreRun != nil {
   348  			oldPreRun(cc, args)
   349  		}
   350  
   351  		// Config options need to be connected to flags before the Run begins.
   352  		// Flags are provided by the connector.
   353  		for i := range allFlags {
   354  			flag := allFlags[i]
   355  			if flag.ConfigEntry == "-" {
   356  				continue
   357  			}
   358  
   359  			flagName := flag.ConfigEntry
   360  			if flagName == "" {
   361  				flagName = flag.Long
   362  			}
   363  
   364  			viper.BindPFlag(flagName, cmd.Flags().Lookup(flag.Long))
   365  		}
   366  	}
   367  
   368  	cmd.Run = func(cc *cobra.Command, args []string) {
   369  		if oldRun != nil {
   370  			oldRun(cc, args)
   371  		}
   372  
   373  		log.Debug().Msg("using provider " + provider.Name + " with connector " + connector.Name)
   374  
   375  		// TODO: replace this hard-coded block. This should be dynamic for all
   376  		// fields that are specified to be passwords with the --ask-field
   377  		// associated with it to make it simple.
   378  		// check if the user used --password without a value
   379  		askPass, err := cc.Flags().GetBool("ask-pass")
   380  		if err == nil && askPass {
   381  			pass, err := components.AskPassword("Enter password: ")
   382  			if err != nil {
   383  				log.Fatal().Err(err).Msg("failed to get password")
   384  			}
   385  			cc.Flags().Set("password", pass)
   386  		}
   387  		// ^^
   388  
   389  		useRecording, err := cc.Flags().GetString("use-recording")
   390  		if err != nil {
   391  			log.Warn().Msg("failed to get flag --recording")
   392  		}
   393  		record, err := cc.Flags().GetString("record")
   394  		if err != nil {
   395  			log.Warn().Msg("failed to get flag --record")
   396  		}
   397  		pretty, err := cc.Flags().GetBool("pretty")
   398  		if err != nil {
   399  			log.Warn().Msg("failed to get flag --pretty")
   400  		}
   401  
   402  		flagVals := map[string]*llx.Primitive{}
   403  		for i := range allFlags {
   404  			flag := allFlags[i]
   405  
   406  			// we skip these because they are coded above
   407  			if _, skip := skipFlags[flag.Long]; skip {
   408  				continue
   409  			}
   410  
   411  			if v := getFlagValue(flag, cmd); v != nil {
   412  				flagVals[flag.Long] = v
   413  			}
   414  		}
   415  
   416  		// TODO: add flag to set timeout and then use RuntimeWithShutdownTimeout
   417  		runtime := providers.Coordinator.NewRuntime()
   418  		if err = providers.SetDefaultRuntime(runtime); err != nil {
   419  			log.Error().Msg(err.Error())
   420  		}
   421  
   422  		autoUpdate := true
   423  		if viper.IsSet("auto_update") {
   424  			autoUpdate = viper.GetBool("auto_update")
   425  		}
   426  
   427  		runtime.AutoUpdate = providers.UpdateProvidersConfig{
   428  			Enabled:         autoUpdate,
   429  			RefreshInterval: 60 * 60,
   430  		}
   431  
   432  		if err := runtime.UseProvider(provider.ID); err != nil {
   433  			providers.Coordinator.Shutdown()
   434  			log.Fatal().Err(err).Msg("failed to start provider " + provider.Name)
   435  		}
   436  
   437  		if record != "" && useRecording != "" {
   438  			log.Fatal().Msg("please only use --record or --use-recording, but not both at the same time")
   439  		}
   440  		recordingPath := record
   441  		if recordingPath == "" {
   442  			recordingPath = useRecording
   443  		}
   444  		doRecord := record != ""
   445  
   446  		recording, err := providers.NewRecording(recordingPath, providers.RecordingOptions{
   447  			DoRecord:        doRecord,
   448  			PrettyPrintJSON: pretty,
   449  		})
   450  		if err != nil {
   451  			log.Fatal().Msg(err.Error())
   452  		}
   453  		runtime.SetRecording(recording)
   454  
   455  		cliRes, err := runtime.Provider.Instance.Plugin.ParseCLI(&plugin.ParseCLIReq{
   456  			Connector: connector.Name,
   457  			Args:      args,
   458  			Flags:     flagVals,
   459  		})
   460  		if err != nil {
   461  			runtime.Close()
   462  			providers.Coordinator.Shutdown()
   463  			log.Fatal().Err(err).Msg("failed to parse cli arguments")
   464  		}
   465  
   466  		if cliRes == nil {
   467  			runtime.Close()
   468  			providers.Coordinator.Shutdown()
   469  			log.Fatal().Msg("failed to process CLI arguments, nothing was returned")
   470  		}
   471  
   472  		run(cc, runtime, cliRes)
   473  		runtime.Close()
   474  		providers.Coordinator.Shutdown()
   475  	}
   476  
   477  	attachFlags(cmd.Flags(), allFlags)
   478  }
   479  
   480  func json2T[T any](s string, empty T) T {
   481  	var res T
   482  	if err := json.Unmarshal([]byte(s), &res); err == nil {
   483  		return res
   484  	}
   485  	return empty
   486  }