github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/flag/options.go (about)

     1  package flag
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"os"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/spf13/cast"
    12  	"github.com/spf13/cobra"
    13  	"github.com/spf13/pflag"
    14  	"github.com/spf13/viper"
    15  	"golang.org/x/xerrors"
    16  
    17  	"github.com/devseccon/trivy/pkg/fanal/analyzer"
    18  	ftypes "github.com/devseccon/trivy/pkg/fanal/types"
    19  	"github.com/devseccon/trivy/pkg/log"
    20  	"github.com/devseccon/trivy/pkg/result"
    21  	"github.com/devseccon/trivy/pkg/types"
    22  	"github.com/devseccon/trivy/pkg/version"
    23  	xstrings "github.com/devseccon/trivy/pkg/x/strings"
    24  )
    25  
    26  type Flag struct {
    27  	// Name is for CLI flag and environment variable.
    28  	// If this field is empty, it will be available only in config file.
    29  	Name string
    30  
    31  	// ConfigName is a key in config file. It is also used as a key of viper.
    32  	ConfigName string
    33  
    34  	// Shorthand is a shorthand letter.
    35  	Shorthand string
    36  
    37  	// Default is the default value. It must be filled to determine the flag type.
    38  	Default any
    39  
    40  	// Values is a list of allowed values.
    41  	// It currently supports string flags and string slice flags only.
    42  	Values []string
    43  
    44  	// ValueNormalize is a function to normalize the value.
    45  	// It can be used for aliases, etc.
    46  	ValueNormalize func(string) string
    47  
    48  	// Usage explains how to use the flag.
    49  	Usage string
    50  
    51  	// Persistent represents if the flag is persistent
    52  	Persistent bool
    53  
    54  	// Deprecated represents if the flag is deprecated
    55  	Deprecated bool
    56  
    57  	// Aliases represents aliases
    58  	Aliases []Alias
    59  }
    60  
    61  type Alias struct {
    62  	Name       string
    63  	ConfigName string
    64  	Deprecated bool
    65  }
    66  
// FlagGroup is a named set of flags. Name is used as the section heading
// ("<Name> Flags") of the usage message, and Flags returns the flags of the
// group. Callers in this file skip nil entries in the returned slice.
type FlagGroup interface {
	Name() string
	Flags() []*Flag
}
    71  
// Flags bundles the flag groups supported by a command. A nil field means the
// command does not use that group: it is left out of groups(), so it is not
// registered, bound or shown in usage, and ToOptions leaves the corresponding
// options at their zero value.
type Flags struct {
	AWSFlagGroup           *AWSFlagGroup
	CacheFlagGroup         *CacheFlagGroup
	CloudFlagGroup         *CloudFlagGroup
	DBFlagGroup            *DBFlagGroup
	ImageFlagGroup         *ImageFlagGroup
	K8sFlagGroup           *K8sFlagGroup
	LicenseFlagGroup       *LicenseFlagGroup
	MisconfFlagGroup       *MisconfFlagGroup
	ModuleFlagGroup        *ModuleFlagGroup
	RemoteFlagGroup        *RemoteFlagGroup
	RegistryFlagGroup      *RegistryFlagGroup
	RegoFlagGroup          *RegoFlagGroup
	RepoFlagGroup          *RepoFlagGroup
	ReportFlagGroup        *ReportFlagGroup
	SBOMFlagGroup          *SBOMFlagGroup
	ScanFlagGroup          *ScanFlagGroup
	SecretFlagGroup        *SecretFlagGroup
	VulnerabilityFlagGroup *VulnerabilityFlagGroup
}
    92  
// Options holds all the runtime configuration.
// Each embedded *Options struct is filled from the matching flag group by
// Flags.ToOptions; GlobalOptions is always populated there.
type Options struct {
	GlobalOptions
	AWSOptions
	CacheOptions
	CloudOptions
	DBOptions
	ImageOptions
	K8sOptions
	LicenseOptions
	MisconfOptions
	ModuleOptions
	RegistryOptions
	RegoOptions
	RemoteOptions
	RepoOptions
	ReportOptions
	SBOMOptions
	ScanOptions
	SecretOptions
	VulnerabilityOptions

	// Trivy's version, not populated via CLI flags
	// (Flags.ToOptions sets it from version.AppVersion()).
	AppVersion string

	// We don't want to allow disabled analyzers to be passed by users, but it is necessary for internal use.
	DisabledAnalyzers []analyzer.Type

	// outputWriter is not initialized via the CLI.
	// It is mainly used for testing purposes or by tools that use Trivy as a library.
	// When non-nil, OutputWriter returns it instead of opening the Output file
	// or falling back to os.Stdout; set it with SetOutputWriter.
	outputWriter io.Writer
}
   125  
   126  // Align takes consistency of options
   127  func (o *Options) Align() {
   128  	if o.Format == types.FormatSPDX || o.Format == types.FormatSPDXJSON {
   129  		log.Logger.Info(`"--format spdx" and "--format spdx-json" disable security scanning`)
   130  		o.Scanners = nil
   131  	}
   132  
   133  	// Vulnerability scanning is disabled by default for CycloneDX.
   134  	if o.Format == types.FormatCycloneDX && !viper.IsSet(ScannersFlag.ConfigName) && len(o.K8sOptions.Components) == 0 { // remove K8sOptions.Components validation check when vuln scan is supported for k8s report with cycloneDX
   135  		log.Logger.Info(`"--format cyclonedx" disables security scanning. Specify "--scanners vuln" explicitly if you want to include vulnerabilities in the CycloneDX report.`)
   136  		o.Scanners = nil
   137  	}
   138  
   139  	if o.Format == types.FormatCycloneDX && len(o.K8sOptions.Components) > 0 {
   140  		log.Logger.Info(`"k8s with --format cyclonedx" disable security scanning`)
   141  		o.Scanners = nil
   142  	}
   143  }
   144  
   145  // RegistryOpts returns options for OCI registries
   146  func (o *Options) RegistryOpts() ftypes.RegistryOptions {
   147  	return ftypes.RegistryOptions{
   148  		Credentials:   o.Credentials,
   149  		RegistryToken: o.RegistryToken,
   150  		Insecure:      o.Insecure,
   151  		Platform:      o.Platform,
   152  		AWSRegion:     o.AWSOptions.Region,
   153  	}
   154  }
   155  
   156  // FilterOpts returns options for filtering
   157  func (o *Options) FilterOpts() result.FilterOption {
   158  	return result.FilterOption{
   159  		Severities:         o.Severities,
   160  		IgnoreStatuses:     o.IgnoreStatuses,
   161  		IncludeNonFailures: o.IncludeNonFailures,
   162  		IgnoreFile:         o.IgnoreFile,
   163  		PolicyFile:         o.IgnorePolicy,
   164  		IgnoreLicenses:     o.IgnoredLicenses,
   165  		VEXPath:            o.VEXPath,
   166  	}
   167  }
   168  
   169  // SetOutputWriter sets an output writer.
   170  func (o *Options) SetOutputWriter(w io.Writer) {
   171  	o.outputWriter = w
   172  }
   173  
   174  // OutputWriter returns an output writer.
   175  // If the output file is not specified, it returns os.Stdout.
   176  func (o *Options) OutputWriter() (io.Writer, func(), error) {
   177  	if o.outputWriter != nil {
   178  		return o.outputWriter, func() {}, nil
   179  	}
   180  
   181  	if o.Output != "" {
   182  		f, err := os.Create(o.Output)
   183  		if err != nil {
   184  			return nil, nil, xerrors.Errorf("failed to create output file: %w", err)
   185  		}
   186  		return f, func() { _ = f.Close() }, nil
   187  	}
   188  	return os.Stdout, func() {}, nil
   189  }
   190  
   191  func addFlag(cmd *cobra.Command, flag *Flag) {
   192  	if flag == nil || flag.Name == "" {
   193  		return
   194  	}
   195  	var flags *pflag.FlagSet
   196  	if flag.Persistent {
   197  		flags = cmd.PersistentFlags()
   198  	} else {
   199  		flags = cmd.Flags()
   200  	}
   201  
   202  	switch v := flag.Default.(type) {
   203  	case int:
   204  		flags.IntP(flag.Name, flag.Shorthand, v, flag.Usage)
   205  	case string:
   206  		usage := flag.Usage
   207  		if len(flag.Values) > 0 {
   208  			usage += fmt.Sprintf(" (%s)", strings.Join(flag.Values, ","))
   209  		}
   210  		flags.VarP(newCustomStringValue(v, flag.Values, flag.ValueNormalize), flag.Name, flag.Shorthand, usage)
   211  	case []string:
   212  		usage := flag.Usage
   213  		if len(flag.Values) > 0 {
   214  			usage += fmt.Sprintf(" (%s)", strings.Join(flag.Values, ","))
   215  		}
   216  		flags.VarP(newCustomStringSliceValue(v, flag.Values, flag.ValueNormalize), flag.Name, flag.Shorthand, usage)
   217  	case bool:
   218  		flags.BoolP(flag.Name, flag.Shorthand, v, flag.Usage)
   219  	case time.Duration:
   220  		flags.DurationP(flag.Name, flag.Shorthand, v, flag.Usage)
   221  	case float64:
   222  		flags.Float64P(flag.Name, flag.Shorthand, v, flag.Usage)
   223  	}
   224  
   225  	if flag.Deprecated {
   226  		flags.MarkHidden(flag.Name) // nolint: gosec
   227  	}
   228  }
   229  
   230  func bind(cmd *cobra.Command, flag *Flag) error {
   231  	if flag == nil {
   232  		return nil
   233  	} else if flag.Name == "" {
   234  		// This flag is available only in trivy.yaml
   235  		viper.SetDefault(flag.ConfigName, flag.Default)
   236  		return nil
   237  	}
   238  
   239  	// Bind CLI flags
   240  	f := cmd.Flags().Lookup(flag.Name)
   241  	if f == nil {
   242  		// Lookup local persistent flags
   243  		f = cmd.PersistentFlags().Lookup(flag.Name)
   244  	}
   245  	if err := viper.BindPFlag(flag.ConfigName, f); err != nil {
   246  		return xerrors.Errorf("bind flag error: %w", err)
   247  	}
   248  
   249  	// Bind environmental variable
   250  	if err := bindEnv(flag); err != nil {
   251  		return err
   252  	}
   253  
   254  	return nil
   255  }
   256  
   257  func bindEnv(flag *Flag) error {
   258  	// We don't use viper.AutomaticEnv, so we need to add a prefix manually here.
   259  	envName := strings.ToUpper("trivy_" + strings.ReplaceAll(flag.Name, "-", "_"))
   260  	if err := viper.BindEnv(flag.ConfigName, envName); err != nil {
   261  		return xerrors.Errorf("bind env error: %w", err)
   262  	}
   263  
   264  	// Bind env aliases
   265  	for _, alias := range flag.Aliases {
   266  		envAlias := strings.ToUpper("trivy_" + strings.ReplaceAll(alias.Name, "-", "_"))
   267  		if err := viper.BindEnv(flag.ConfigName, envAlias); err != nil {
   268  			return xerrors.Errorf("bind env error: %w", err)
   269  		}
   270  		if alias.Deprecated {
   271  			if _, ok := os.LookupEnv(envAlias); ok {
   272  				log.Logger.Warnf("'%s' is deprecated. Use '%s' instead.", envAlias, envName)
   273  			}
   274  		}
   275  	}
   276  	return nil
   277  }
   278  
   279  func getString(flag *Flag) string {
   280  	return cast.ToString(getValue(flag))
   281  }
   282  
   283  func getUnderlyingString[T xstrings.String](flag *Flag) T {
   284  	s := getString(flag)
   285  	return T(s)
   286  }
   287  
   288  func getStringSlice(flag *Flag) []string {
   289  	// viper always returns a string for ENV
   290  	// https://github.com/spf13/viper/blob/419fd86e49ef061d0d33f4d1d56d5e2a480df5bb/viper.go#L545-L553
   291  	// and uses strings.Field to separate values (whitespace only)
   292  	// we need to separate env values with ','
   293  	v := cast.ToStringSlice(getValue(flag))
   294  	switch {
   295  	case len(v) == 0: // no strings
   296  		return nil
   297  	case len(v) == 1 && strings.Contains(v[0], ","): // unseparated string
   298  		v = strings.Split(v[0], ",")
   299  	}
   300  	return v
   301  }
   302  
   303  func getUnderlyingStringSlice[T xstrings.String](flag *Flag) []T {
   304  	ss := getStringSlice(flag)
   305  	if len(ss) == 0 {
   306  		return nil
   307  	}
   308  	return xstrings.ToTSlice[T](ss)
   309  }
   310  
   311  func getInt(flag *Flag) int {
   312  	return cast.ToInt(getValue(flag))
   313  }
   314  
   315  func getFloat(flag *Flag) float64 {
   316  	return cast.ToFloat64(getValue(flag))
   317  }
   318  
   319  func getBool(flag *Flag) bool {
   320  	return cast.ToBool(getValue(flag))
   321  }
   322  
   323  func getDuration(flag *Flag) time.Duration {
   324  	return cast.ToDuration(getValue(flag))
   325  }
   326  
   327  func getValue(flag *Flag) any {
   328  	if flag == nil {
   329  		return nil
   330  	}
   331  
   332  	// First, looks for aliases in config file (trivy.yaml).
   333  	// Note that viper.RegisterAlias cannot be used for this purpose.
   334  	var v any
   335  	for _, alias := range flag.Aliases {
   336  		if alias.ConfigName == "" {
   337  			continue
   338  		}
   339  		v = viper.Get(alias.ConfigName)
   340  		if v != nil {
   341  			log.Logger.Warnf("'%s' in config file is deprecated. Use '%s' instead.", alias.ConfigName, flag.ConfigName)
   342  			return v
   343  		}
   344  	}
   345  	return viper.Get(flag.ConfigName)
   346  }
   347  
   348  func (f *Flags) groups() []FlagGroup {
   349  	var groups []FlagGroup
   350  	// This order affects the usage message, so they are sorted by frequency of use.
   351  	if f.ScanFlagGroup != nil {
   352  		groups = append(groups, f.ScanFlagGroup)
   353  	}
   354  	if f.ReportFlagGroup != nil {
   355  		groups = append(groups, f.ReportFlagGroup)
   356  	}
   357  	if f.CacheFlagGroup != nil {
   358  		groups = append(groups, f.CacheFlagGroup)
   359  	}
   360  	if f.DBFlagGroup != nil {
   361  		groups = append(groups, f.DBFlagGroup)
   362  	}
   363  	if f.RegistryFlagGroup != nil {
   364  		groups = append(groups, f.RegistryFlagGroup)
   365  	}
   366  	if f.ImageFlagGroup != nil {
   367  		groups = append(groups, f.ImageFlagGroup)
   368  	}
   369  	if f.SBOMFlagGroup != nil {
   370  		groups = append(groups, f.SBOMFlagGroup)
   371  	}
   372  	if f.VulnerabilityFlagGroup != nil {
   373  		groups = append(groups, f.VulnerabilityFlagGroup)
   374  	}
   375  	if f.MisconfFlagGroup != nil {
   376  		groups = append(groups, f.MisconfFlagGroup)
   377  	}
   378  	if f.ModuleFlagGroup != nil {
   379  		groups = append(groups, f.ModuleFlagGroup)
   380  	}
   381  	if f.SecretFlagGroup != nil {
   382  		groups = append(groups, f.SecretFlagGroup)
   383  	}
   384  	if f.LicenseFlagGroup != nil {
   385  		groups = append(groups, f.LicenseFlagGroup)
   386  	}
   387  	if f.RegoFlagGroup != nil {
   388  		groups = append(groups, f.RegoFlagGroup)
   389  	}
   390  	if f.CloudFlagGroup != nil {
   391  		groups = append(groups, f.CloudFlagGroup)
   392  	}
   393  	if f.AWSFlagGroup != nil {
   394  		groups = append(groups, f.AWSFlagGroup)
   395  	}
   396  	if f.K8sFlagGroup != nil {
   397  		groups = append(groups, f.K8sFlagGroup)
   398  	}
   399  	if f.RemoteFlagGroup != nil {
   400  		groups = append(groups, f.RemoteFlagGroup)
   401  	}
   402  	if f.RepoFlagGroup != nil {
   403  		groups = append(groups, f.RepoFlagGroup)
   404  	}
   405  	return groups
   406  }
   407  
   408  func (f *Flags) AddFlags(cmd *cobra.Command) {
   409  	aliases := make(flagAliases)
   410  	for _, group := range f.groups() {
   411  		for _, flag := range group.Flags() {
   412  			addFlag(cmd, flag)
   413  
   414  			// Register flag aliases
   415  			aliases.Add(flag)
   416  		}
   417  	}
   418  
   419  	cmd.Flags().SetNormalizeFunc(aliases.NormalizeFunc())
   420  }
   421  
   422  func (f *Flags) Usages(cmd *cobra.Command) string {
   423  	var usages string
   424  	for _, group := range f.groups() {
   425  
   426  		flags := pflag.NewFlagSet(cmd.Name(), pflag.ContinueOnError)
   427  		lflags := cmd.LocalFlags()
   428  		for _, flag := range group.Flags() {
   429  			if flag == nil || flag.Name == "" {
   430  				continue
   431  			}
   432  			flags.AddFlag(lflags.Lookup(flag.Name))
   433  		}
   434  		if !flags.HasAvailableFlags() {
   435  			continue
   436  		}
   437  
   438  		usages += fmt.Sprintf("%s Flags\n", group.Name())
   439  		usages += flags.FlagUsages() + "\n"
   440  	}
   441  	return strings.TrimSpace(usages)
   442  }
   443  
   444  func (f *Flags) Bind(cmd *cobra.Command) error {
   445  	for _, group := range f.groups() {
   446  		if group == nil {
   447  			continue
   448  		}
   449  		for _, flag := range group.Flags() {
   450  			if err := bind(cmd, flag); err != nil {
   451  				return xerrors.Errorf("flag groups: %w", err)
   452  			}
   453  		}
   454  	}
   455  	return nil
   456  }
   457  
   458  // nolint: gocyclo
   459  func (f *Flags) ToOptions(args []string, globalFlags *GlobalFlagGroup) (Options, error) {
   460  	var err error
   461  	opts := Options{
   462  		AppVersion:    version.AppVersion(),
   463  		GlobalOptions: globalFlags.ToOptions(),
   464  	}
   465  
   466  	if f.AWSFlagGroup != nil {
   467  		opts.AWSOptions = f.AWSFlagGroup.ToOptions()
   468  	}
   469  
   470  	if f.CloudFlagGroup != nil {
   471  		opts.CloudOptions = f.CloudFlagGroup.ToOptions()
   472  	}
   473  
   474  	if f.CacheFlagGroup != nil {
   475  		opts.CacheOptions, err = f.CacheFlagGroup.ToOptions()
   476  		if err != nil {
   477  			return Options{}, xerrors.Errorf("cache flag error: %w", err)
   478  		}
   479  	}
   480  
   481  	if f.DBFlagGroup != nil {
   482  		opts.DBOptions, err = f.DBFlagGroup.ToOptions()
   483  		if err != nil {
   484  			return Options{}, xerrors.Errorf("flag error: %w", err)
   485  		}
   486  	}
   487  
   488  	if f.ImageFlagGroup != nil {
   489  		opts.ImageOptions, err = f.ImageFlagGroup.ToOptions()
   490  		if err != nil {
   491  			return Options{}, xerrors.Errorf("image flag error: %w", err)
   492  		}
   493  	}
   494  
   495  	if f.K8sFlagGroup != nil {
   496  		opts.K8sOptions, err = f.K8sFlagGroup.ToOptions()
   497  		if err != nil {
   498  			return Options{}, xerrors.Errorf("k8s flag error: %w", err)
   499  		}
   500  	}
   501  
   502  	if f.LicenseFlagGroup != nil {
   503  		opts.LicenseOptions = f.LicenseFlagGroup.ToOptions()
   504  	}
   505  
   506  	if f.MisconfFlagGroup != nil {
   507  		opts.MisconfOptions, err = f.MisconfFlagGroup.ToOptions()
   508  		if err != nil {
   509  			return Options{}, xerrors.Errorf("misconfiguration flag error: %w", err)
   510  		}
   511  	}
   512  
   513  	if f.ModuleFlagGroup != nil {
   514  		opts.ModuleOptions = f.ModuleFlagGroup.ToOptions()
   515  	}
   516  
   517  	if f.RegoFlagGroup != nil {
   518  		opts.RegoOptions, err = f.RegoFlagGroup.ToOptions()
   519  		if err != nil {
   520  			return Options{}, xerrors.Errorf("rego flag error: %w", err)
   521  		}
   522  	}
   523  
   524  	if f.RemoteFlagGroup != nil {
   525  		opts.RemoteOptions = f.RemoteFlagGroup.ToOptions()
   526  	}
   527  
   528  	if f.RegistryFlagGroup != nil {
   529  		opts.RegistryOptions, err = f.RegistryFlagGroup.ToOptions()
   530  		if err != nil {
   531  			return Options{}, xerrors.Errorf("registry flag error: %w", err)
   532  		}
   533  	}
   534  
   535  	if f.RepoFlagGroup != nil {
   536  		opts.RepoOptions = f.RepoFlagGroup.ToOptions()
   537  	}
   538  
   539  	if f.ReportFlagGroup != nil {
   540  		opts.ReportOptions, err = f.ReportFlagGroup.ToOptions()
   541  		if err != nil {
   542  			return Options{}, xerrors.Errorf("report flag error: %w", err)
   543  		}
   544  	}
   545  
   546  	if f.SBOMFlagGroup != nil {
   547  		opts.SBOMOptions, err = f.SBOMFlagGroup.ToOptions()
   548  		if err != nil {
   549  			return Options{}, xerrors.Errorf("sbom flag error: %w", err)
   550  		}
   551  	}
   552  
   553  	if f.ScanFlagGroup != nil {
   554  		opts.ScanOptions, err = f.ScanFlagGroup.ToOptions(args)
   555  		if err != nil {
   556  			return Options{}, xerrors.Errorf("scan flag error: %w", err)
   557  		}
   558  	}
   559  
   560  	if f.SecretFlagGroup != nil {
   561  		opts.SecretOptions = f.SecretFlagGroup.ToOptions()
   562  	}
   563  
   564  	if f.VulnerabilityFlagGroup != nil {
   565  		opts.VulnerabilityOptions = f.VulnerabilityFlagGroup.ToOptions()
   566  	}
   567  
   568  	opts.Align()
   569  
   570  	return opts, nil
   571  }
   572  
// flagAlias describes one CLI alias: the canonical flag name it maps to and
// whether the alias is deprecated. once ensures the deprecation warning is
// logged at most once, because the normalize function is called several times
// for the same name. It holds a sync.Once, so it is stored by pointer.
type flagAlias struct {
	formalName string
	deprecated bool
	once       sync.Once
}
   578  
// flagAliases maps a CLI alias name to its flagAlias entry. Values are
// pointers so each entry's sync.Once is shared rather than copied.
type flagAliases map[string]*flagAlias
   581  
   582  func (a flagAliases) Add(flag *Flag) {
   583  	if flag == nil {
   584  		return
   585  	}
   586  	for _, alias := range flag.Aliases {
   587  		a[alias.Name] = &flagAlias{
   588  			formalName: flag.Name,
   589  			deprecated: alias.Deprecated,
   590  		}
   591  	}
   592  }
   593  
   594  func (a flagAliases) NormalizeFunc() func(*pflag.FlagSet, string) pflag.NormalizedName {
   595  	return func(_ *pflag.FlagSet, name string) pflag.NormalizedName {
   596  		if alias, ok := a[name]; ok {
   597  			if alias.deprecated {
   598  				// NormalizeFunc is called several times
   599  				alias.once.Do(func() {
   600  					log.Logger.Warnf("'--%s' is deprecated. Use '--%s' instead.", name, alias.formalName)
   601  				})
   602  			}
   603  			name = alias.formalName
   604  		}
   605  		return pflag.NormalizedName(name)
   606  	}
   607  }