github.com/criteo/command-launcher@v0.0.0-20230407142452-fb616f546e98/internal/frontend/default-frontend.go (about)

     1  package frontend
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"os"
     7  	"strconv"
     8  	"strings"
     9  
    10  	"github.com/criteo/command-launcher/cmd/consent"
    11  	"github.com/criteo/command-launcher/internal/backend"
    12  	"github.com/criteo/command-launcher/internal/command"
    13  	"github.com/criteo/command-launcher/internal/config"
    14  	"github.com/criteo/command-launcher/internal/console"
    15  	"github.com/criteo/command-launcher/internal/context"
    16  	"github.com/criteo/command-launcher/internal/helper"
    17  
    18  	log "github.com/sirupsen/logrus"
    19  
    20  	"github.com/spf13/cobra"
    21  	"github.com/spf13/pflag"
    22  	"github.com/spf13/viper"
    23  )
    24  
    25  const (
    26  	EXECUTABLE_NOT_DEFINED = "Executable not defined"
    27  )
    28  
    29  var (
    30  	RootExitCode = 0
    31  )
    32  
    33  type defaultFrontend struct {
    34  	rootCmd *cobra.Command
    35  
    36  	appCtx  context.LauncherContext
    37  	backend backend.Backend
    38  
    39  	groupCmds      map[string]*cobra.Command
    40  	executableCmds map[string]*cobra.Command
    41  }
    42  
    43  func NewDefaultFrontend(appCtx context.LauncherContext, rootCmd *cobra.Command, backend backend.Backend) Frontend {
    44  	frontend := &defaultFrontend{
    45  		appCtx:  appCtx,
    46  		rootCmd: rootCmd,
    47  		backend: backend,
    48  
    49  		groupCmds:      make(map[string]*cobra.Command),
    50  		executableCmds: make(map[string]*cobra.Command),
    51  	}
    52  	return frontend
    53  }
    54  
    55  func (self *defaultFrontend) AddUserCommands() {
    56  	self.addGroupCommands()
    57  	self.addExecutableCommands()
    58  }
    59  
    60  func (self *defaultFrontend) addGroupCommands() {
    61  	groups := self.backend.GroupCommands()
    62  	for _, v := range groups {
    63  		group := v.RuntimeGroup()
    64  		name := v.RuntimeName()
    65  		usage := strings.TrimSpace(fmt.Sprintf("%s %s",
    66  			v.RuntimeName(),
    67  			strings.TrimSpace(strings.Trim(v.ArgsUsage(), v.RuntimeName())),
    68  		))
    69  		requiredFlags := v.RequiredFlags()
    70  		requestedResources := v.RequestedResources()
    71  		flags := v.Flags()
    72  		exclusiveFlags := v.ExclusiveFlags()
    73  		groupFlags := v.GroupFlags()
    74  		cmd := &cobra.Command{
    75  			DisableFlagParsing: true, // not enable the checkFlags feature for group command for now
    76  			Use:                usage,
    77  			Example:            formatExamples(v.Examples()),
    78  			Short:              v.ShortDescription(),
    79  			Long:               v.LongDescription(),
    80  			Run: func(cmd *cobra.Command, args []string) {
    81  				consents, err := consent.GetConsents(group, name, requestedResources, viper.GetBool(config.ENABLE_USER_CONSENT_KEY))
    82  				if err != nil {
    83  					log.Warnf("failed to get user consent: %v", err)
    84  				}
    85  				exitCode, err := self.executeCommand(group, name, args, []string{}, consents)
    86  				if err != nil && err.Error() == EXECUTABLE_NOT_DEFINED {
    87  					cmd.Help()
    88  				}
    89  				RootExitCode = exitCode
    90  			},
    91  		}
    92  		// legacy flag definition ("requiredFlags")
    93  		// deprecated
    94  		for _, flag := range requiredFlags {
    95  			addFlagToCmd(cmd, flag)
    96  		}
    97  		// new ways to handle flags, first arguments "checkFlags" is always false for group command
    98  		self.processFlags(false, group, name, cmd, flags, exclusiveFlags, groupFlags)
    99  
   100  		self.groupCmds[v.RuntimeName()] = cmd
   101  		self.rootCmd.AddCommand(cmd)
   102  	}
   103  }
   104  
   105  func (self *defaultFrontend) addExecutableCommands() {
   106  	executables := self.backend.ExecutableCommands()
   107  	for _, v := range executables {
   108  		group := v.RuntimeGroup()
   109  		name := v.RuntimeName()
   110  		usage := strings.TrimSpace(fmt.Sprintf("%s %s",
   111  			v.RuntimeName(),
   112  			strings.TrimSpace(strings.Trim(v.ArgsUsage(), v.RuntimeName())),
   113  		))
   114  		requiredFlags := v.RequiredFlags()
   115  		validArgs := v.ValidArgs()
   116  		validArgsCmd := v.ValidArgsCmd()
   117  		checkFlags := v.CheckFlags()
   118  		requestedResources := v.RequestedResources()
   119  		flags := v.Flags()
   120  		exclusiveFlags := v.ExclusiveFlags()
   121  		groupFlags := v.GroupFlags()
   122  		cmd := &cobra.Command{
   123  			DisableFlagParsing: !checkFlags,
   124  			Use:                usage,
   125  			Example:            formatExamples(v.Examples()),
   126  			Short:              v.ShortDescription(),
   127  			Long:               v.LongDescription(),
   128  			Run: func(c *cobra.Command, args []string) {
   129  				consents, err := consent.GetConsents(group, name, requestedResources, viper.GetBool(config.ENABLE_USER_CONSENT_KEY))
   130  				if err != nil {
   131  					log.Warnf("failed to get user consent: %v", err)
   132  				}
   133  
   134  				envVars, originalArgs, code, shouldQuit := self.parseArgsToEnvVars(c, args, checkFlags)
   135  				if shouldQuit {
   136  					RootExitCode = code
   137  					return
   138  				}
   139  
   140  				if exitCode, err := self.executeCommand(group, name, originalArgs, envVars, consents); err != nil {
   141  					RootExitCode = exitCode
   142  				}
   143  			},
   144  		}
   145  
   146  		// legacy flag definition ("requiredFlags")
   147  		// deprecated
   148  		for _, flag := range requiredFlags {
   149  			addFlagToCmd(cmd, flag)
   150  		}
   151  		// new ways to handle flags
   152  		self.processFlags(checkFlags, group, name, cmd, flags, exclusiveFlags, groupFlags)
   153  
   154  		cmd.ValidArgsFunction = func(c *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
   155  			if len(validArgsCmd) > 0 {
   156  	      var originalArgs = args
   157          if checkFlags {
   158  					c.LocalFlags().VisitAll(func(flag *pflag.Flag) {
   159  						switch flag.Value.Type() {
   160  						case "bool":
   161  							if flag.Value.String() == "true" {
   162  								originalArgs = append(originalArgs, fmt.Sprintf("--%s", flag.Name))
   163  							}
   164  						default:
   165  							if flag.Value.String() != "" {
   166  								originalArgs = append(originalArgs, fmt.Sprintf("--%s", flag.Name), flag.Value.String())
   167  							}
   168  						}
   169  					})
   170  				}
   171  				output, err := self.executeValidArgsOfCommand(group, name, originalArgs)
   172  				if err != nil {
   173  					return []string{}, cobra.ShellCompDirectiveNoFileComp
   174  				}
   175  				parts := strings.Split(output, "\n")
   176  				if len(parts) > 0 {
   177  					if strings.HasPrefix(parts[0], "#") { // skip the first control line, for further controls
   178  						// the first line starting with # is the control line, it controls the completion behavior when the return body is empty
   179  						shellDirective := cobra.ShellCompDirectiveNoFileComp
   180  						switch strings.TrimSpace(strings.TrimLeft(parts[0], "#")) {
   181  						case "dir-completion-only":
   182  							shellDirective = cobra.ShellCompDirectiveFilterDirs
   183  						case "default":
   184  							shellDirective = cobra.ShellCompDirectiveDefault
   185  						case "no-file-completion":
   186  							shellDirective = cobra.ShellCompDirectiveNoFileComp
   187  						}
   188  						return parts[1:], shellDirective
   189  					}
   190  					return parts, cobra.ShellCompDirectiveNoFileComp
   191  				}
   192  			}
   193  			if len(validArgs) > 0 {
   194  				return validArgs, cobra.ShellCompDirectiveNoFileComp
   195  			}
   196  			return []string{}, cobra.ShellCompDirectiveDefault
   197  		}
   198  
   199  		if v.RuntimeGroup() == "" {
   200  			self.rootCmd.AddCommand(cmd)
   201  		} else {
   202  			if group, exists := self.groupCmds[v.RuntimeGroup()]; exists {
   203  				group.AddCommand(cmd)
   204  			} else {
   205  				log.Errorf("cannot install cmd %s in group %s: group not found", v.Name(), v.Group())
   206  			}
   207  		}
   208  
   209  	}
   210  }
   211  
   212  // parse args and inject environment vars
   213  // if checkFlags is disabled, it simply returns the empty variables, and the args input
   214  // otherwise, return the environment vars, original args, exit code, and if we should exit
   215  func (self *defaultFrontend) parseArgsToEnvVars(c *cobra.Command, args []string, checkFlags bool) ([]string, []string, int, bool) {
   216  	var envVars []string = []string{}
   217  	var envTable map[string]string = map[string]string{}
   218  	var originalArgs = args
   219  
   220  	log.Debugf("checkFlags=%t", checkFlags)
   221  	if checkFlags {
   222  		var err error = nil
   223  		envVarPrefix := strings.ToUpper(self.appCtx.AppName())
   224  		envVars, envTable, originalArgs, err = parseCmdArgsToEnv(c, args, envVarPrefix)
   225  		if err != nil {
   226  			console.Error("Failed to parse arguments: %v", err)
   227  			// set exit code to 1, and should quit
   228  			return envVars, originalArgs, 1, true
   229  		}
   230  		if h, exist := envTable[fmt.Sprintf("%s_FLAG_HELP", envVarPrefix)]; exist && h == "true" {
   231  			c.Help()
   232  			// show help and should quit
   233  			return envVars, originalArgs, 0, true
   234  		}
   235  	}
   236  	log.Debugf("flag & args environments: %v", envVars)
   237  
   238  	return envVars, originalArgs, 0, false
   239  }
   240  
   241  func formatExamples(examples []command.ExampleEntry) string {
   242  	if examples == nil || len(examples) == 0 {
   243  		return ""
   244  	}
   245  
   246  	output := []string{}
   247  
   248  	for _, v := range examples {
   249  		output = append(output, fmt.Sprintf(`  # %s
   250    %s
   251  `, v.Scenario, v.Command))
   252  	}
   253  
   254  	return strings.Join(output, "\n")
   255  }
   256  
   257  func (self *defaultFrontend) getExecutableCommand(group, name string) (command.Command, error) {
   258  	iCmd, err := self.backend.FindCommand(group, name)
   259  	return iCmd, err
   260  }
   261  
   262  // execute a cdt command
   263  func (self *defaultFrontend) executeCommand(group, name string, args []string, initialEnvCtx []string, consent []string) (int, error) {
   264  	iCmd, err := self.getExecutableCommand(group, name)
   265  	if err != nil {
   266  		return 1, err
   267  	}
   268  	if iCmd.Executable() == "" {
   269  		return 1, errors.New(EXECUTABLE_NOT_DEFINED)
   270  	}
   271  
   272  	envCtx := self.getCmdEnvContext(initialEnvCtx, consent)
   273  	envCtx = append(envCtx, fmt.Sprintf("%s=%s", self.appCtx.CmdPackageDirEnvVar(), iCmd.PackageDir()))
   274  
   275  	exitCode, err := iCmd.Execute(envCtx, args...)
   276  	if err != nil {
   277  		return exitCode, err
   278  	}
   279  
   280  	return exitCode, nil
   281  }
   282  
   283  // execute the valid args command of the cdt command
   284  func (self *defaultFrontend) executeValidArgsOfCommand(group, name string, args []string) (string, error) {
   285  	iCmd, err := self.getExecutableCommand(group, name)
   286  	if err != nil {
   287  		return "", err
   288  	}
   289  
   290  	envCtx := self.getCmdEnvContext([]string{}, []string{})
   291  
   292  	_, output, err := iCmd.ExecuteValidArgsCmd(envCtx, args...)
   293  	if err != nil {
   294  		return "", err
   295  	}
   296  
   297  	return output, nil
   298  }
   299  
   300  // execute the flag values command of the cdt command
   301  func (self *defaultFrontend) executeFlagValuesOfCommand(group, name string, flagCmd []string, args []string) (string, error) {
   302  	iCmd, err := self.getExecutableCommand(group, name)
   303  	if err != nil {
   304  		return "", err
   305  	}
   306  
   307  	envCtx := self.getCmdEnvContext([]string{}, []string{})
   308  
   309  	_, output, err := iCmd.ExecuteFlagValuesCmd(envCtx, flagCmd, args...)
   310  	if err != nil {
   311  		return "", err
   312  	}
   313  
   314  	return output, nil
   315  }
   316  
   317  func addFlagToCmd(cmd *cobra.Command, flag string) {
   318  	flagName, flagShort, flagDesc, flagType, defaultValue := parseFlagDefinition(flag)
   319  	switch flagType {
   320  	case "bool":
   321  		// always use false as the default for the bool type
   322  		cmd.Flags().BoolP(flagName, flagShort, false, flagDesc)
   323  	default:
   324  		cmd.Flags().StringP(flagName, flagShort, defaultValue, flagDesc)
   325  	}
   326  }
   327  
   328  func (self *defaultFrontend) processFlags(checkFlags bool, cmdGroup, cmdName string, cmd *cobra.Command, flags []command.Flag, exclusive [][]string, group [][]string) {
   329  	for _, flag := range flags {
   330  		switch flag.Type() {
   331  		case "bool":
   332  			defaultV, err := strconv.ParseBool(flag.Default())
   333  			if err != nil {
   334  				defaultV = false
   335  			}
   336  			cmd.Flags().BoolP(flag.Name(), flag.ShortName(), defaultV, flag.Description())
   337  		default:
   338  			cmd.Flags().StringP(flag.Name(), flag.ShortName(), flag.Default(), flag.Description())
   339  		}
   340  
   341  		if flag.Required() {
   342  			cmd.MarkFlagRequired(flag.Name())
   343  		}
   344  
   345  		// register auto completion for flag values
   346  		if len(flag.Values()) > 0 {
   347  			// static list
   348  			values := flag.Values()
   349  			cmd.RegisterFlagCompletionFunc(flag.Name(), func(c *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
   350  				return values, cobra.ShellCompDirectiveNoFileComp
   351  			})
   352  		} else if len(flag.ValuesCmd()) > 0 {
   353  			// dynamic list
   354  			valuesCmd := flag.ValuesCmd()
   355  			cmd.RegisterFlagCompletionFunc(flag.Name(), func(c *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
   356  				originalArgs := args
   357  				// when checkFlags is true, we need to recover the original arguments
   358  				if checkFlags {
   359  					c.LocalFlags().VisitAll(func(flag *pflag.Flag) {
   360  						switch flag.Value.Type() {
   361  						case "bool":
   362  							if flag.Value.String() == "true" {
   363  								originalArgs = append(originalArgs, fmt.Sprintf("--%s", flag.Name))
   364  							}
   365  						default:
   366  							if flag.Value.String() != "" {
   367  								originalArgs = append(originalArgs, fmt.Sprintf("--%s", flag.Name), flag.Value.String())
   368  							}
   369  						}
   370  					})
   371  				}
   372  				output, err := self.executeFlagValuesOfCommand(cmdGroup, cmdName, valuesCmd, originalArgs)
   373  				if err != nil {
   374  					return []string{}, cobra.ShellCompDirectiveDefault
   375  				}
   376  				return strings.Split(output, "\n"), cobra.ShellCompDirectiveNoFileComp
   377  			})
   378  		}
   379  	}
   380  
   381  	for _, ex := range exclusive {
   382  		exist := true
   383  		for _, f := range ex {
   384  			if cmd.Flags().Lookup(f) == nil {
   385  				exist = false
   386  			}
   387  		}
   388  		if !exist {
   389  			continue
   390  		}
   391  		cmd.MarkFlagsMutuallyExclusive(ex...)
   392  	}
   393  	for _, ex := range group {
   394  		exist := true
   395  		for _, f := range ex {
   396  			if cmd.Flags().Lookup(f) == nil {
   397  				exist = false
   398  			}
   399  		}
   400  		if !exist {
   401  			continue
   402  		}
   403  		cmd.MarkFlagsRequiredTogether(ex...)
   404  	}
   405  }
   406  
   407  func parseFlagDefinition(line string) (string, string, string, string, string) {
   408  	flagParts := strings.Split(line, "\t")
   409  	name := strings.TrimSpace(flagParts[0])
   410  	short := ""
   411  	description := ""
   412  	flagType := "string"
   413  	defaultValue := ""
   414  	if len(flagParts) == 2 {
   415  		description = strings.TrimSpace(flagParts[1])
   416  	}
   417  	if len(flagParts) > 2 {
   418  		short = strings.TrimSpace(flagParts[1])
   419  		description = strings.TrimSpace(flagParts[2])
   420  	}
   421  	if len(flagParts) > 3 {
   422  		flagType = strings.TrimSpace(flagParts[3])
   423  	}
   424  	if len(flagParts) > 4 {
   425  		defaultValue = strings.TrimSpace(flagParts[4])
   426  	}
   427  
   428  	return name, short, description, flagType, defaultValue
   429  }
   430  
   431  func (self *defaultFrontend) getCmdEnvContext(envVars []string, consents []string) []string {
   432  	vars := append([]string{}, envVars...)
   433  
   434  	for _, item := range consents {
   435  		switch item {
   436  		case consent.USERNAME:
   437  			username, err := helper.GetUsername()
   438  			if err != nil {
   439  				username = ""
   440  			}
   441  			if username != "" {
   442  				vars = append(vars, fmt.Sprintf("%s=%s", self.appCtx.UsernameEnvVar(), username))
   443  			}
   444  		case consent.PASSWORD:
   445  			password, err := helper.GetPassword()
   446  			if err != nil {
   447  				password = ""
   448  			}
   449  			if password != "" {
   450  				vars = append(vars, fmt.Sprintf("%s=%s", self.appCtx.PasswordEnvVar(), password))
   451  			}
   452  		case consent.AUTH_TOKEN:
   453  			token, err := helper.GetAuthToken()
   454  			if err != nil {
   455  				token = ""
   456  			}
   457  			if token != "" {
   458  				vars = append(vars, fmt.Sprintf("%s=%s", self.appCtx.AuthTokenEnvVar(), token))
   459  			}
   460  		case consent.LOG_LEVEL:
   461  			// append log level from configuration
   462  			logLevel := viper.GetString(config.LOG_LEVEL_KEY)
   463  			vars = append(vars, fmt.Sprintf("%s=%s",
   464  				self.appCtx.LogLevelEnvVar(),
   465  				logLevel,
   466  			))
   467  		case consent.DEBUG_FLAGS:
   468  			// append debug flags from configuration
   469  			debugFlags := os.Getenv(self.appCtx.DebugFlagsEnvVar())
   470  			vars = append(vars, fmt.Sprintf("%s=%s,%s",
   471  				self.appCtx.DebugFlagsEnvVar(),
   472  				debugFlags,
   473  				viper.GetString(config.DEBUG_FLAGS_KEY),
   474  			))
   475  		}
   476  	}
   477  
   478  	// Enable variable with prefix [binary_name] and COLA
   479  	// TODO: remove it when in version 1.8 all variables are migrated to COLA prefix
   480  	outputVars := []string{}
   481  	for _, v := range vars {
   482  		prefix := fmt.Sprintf("%s_", strings.ToUpper(self.appCtx.AppName()))
   483  		if strings.HasPrefix(v, prefix) && prefix != "COLA_" {
   484  			outputVars = append(outputVars, strings.Replace(v, prefix, "COLA_", 1))
   485  		}
   486  		outputVars = append(outputVars, v)
   487  	}
   488  
   489  	return outputVars
   490  }
   491  
   492  // return environment variable list, env variable table, original args including flags
   493  func parseCmdArgsToEnv(c *cobra.Command, args []string, envVarPrefix string) ([]string, map[string]string, []string, error) {
   494  	envVars := []string{}
   495  	envTable := map[string]string{}
   496  	originalArgs := []string{}
   497  	if err := c.LocalFlags().Parse(args); err != nil {
   498  		return envVars, envTable, args, err
   499  	}
   500  	c.LocalFlags().VisitAll(func(flag *pflag.Flag) {
   501  		n := strings.ReplaceAll(strings.ToUpper(flag.Name), "-", "_")
   502  		v := flag.Value.String()
   503  		k := fmt.Sprintf("%s_FLAG_%s", envVarPrefix, n)
   504  		envVars = append(envVars,
   505  			fmt.Sprintf(
   506  				"%s=%s",
   507  				k, v,
   508  			),
   509  		)
   510  		envTable[k] = v
   511  
   512  		switch flag.Value.Type() {
   513  		case "bool":
   514  			if flag.Value.String() == "true" {
   515  				originalArgs = append(originalArgs, fmt.Sprintf("--%s", flag.Name))
   516  			}
   517  		default:
   518  			if flag.Value.String() != "" {
   519  				originalArgs = append(originalArgs, fmt.Sprintf("--%s", flag.Name), flag.Value.String())
   520  			}
   521  		}
   522  
   523  	})
   524  	for idx, arg := range c.LocalFlags().Args() {
   525  		k := fmt.Sprintf("%s_ARG_%s", envVarPrefix, strconv.Itoa(idx+1))
   526  		envVars = append(envVars, fmt.Sprintf("%s=%s", k, arg))
   527  		envTable[k] = arg
   528  	}
   529  	// new variable for arg number
   530  	nargs_k := fmt.Sprintf("%s_NARGS", envVarPrefix)
   531  	envTable[nargs_k] = strconv.Itoa(len(c.LocalFlags().Args()))
   532  	envVars = append(envVars, fmt.Sprintf("%s=%s", nargs_k, envTable[nargs_k]))
   533  
   534  	// reconstruct the original command args including flags
   535  	parsedArgs := c.LocalFlags().Args()
   536  	originalArgs = append(originalArgs, parsedArgs...)
   537  
   538  	return envVars, envTable, originalArgs, nil
   539  }