github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/flag/value.go (about)

     1  package flag
     2  
     3  import (
     4  	"strings"
     5  
     6  	"github.com/samber/lo"
     7  	"golang.org/x/exp/slices"
     8  	"golang.org/x/xerrors"
     9  )
    10  
    11  type ValueNormalizeFunc func(string) string
    12  
// -- string Value

// customStringValue is a flag value holding a single string.
//
// Fields:
//   - value: the current value. Set stores a pointer to the newly accepted
//     string instead of writing through the existing pointer.
//   - allowed: when non-empty, the only values Set accepts (compared after
//     normalization); when empty, any value is accepted.
//   - normalize: optional transform applied to the input before validation;
//     nil means the input is used as-is.
type customStringValue struct {
	value     *string
	allowed   []string
	normalize ValueNormalizeFunc
}
    19  
    20  func newCustomStringValue(val string, allowed []string, fn ValueNormalizeFunc) *customStringValue {
    21  	return &customStringValue{
    22  		value:     &val,
    23  		allowed:   allowed,
    24  		normalize: fn,
    25  	}
    26  }
    27  
    28  func (s *customStringValue) Set(val string) error {
    29  	if s.normalize != nil {
    30  		val = s.normalize(val)
    31  	}
    32  	if len(s.allowed) > 0 && !slices.Contains(s.allowed, val) {
    33  		return xerrors.Errorf("must be one of %q", s.allowed)
    34  	}
    35  	s.value = &val
    36  	return nil
    37  }
    38  func (s *customStringValue) Type() string {
    39  	return "string"
    40  }
    41  
    42  func (s *customStringValue) String() string { return *s.value }
    43  
// -- stringSlice Value

// customStringSliceValue is a flag value holding a list of strings parsed
// from comma-separated input.
//
// Fields:
//   - value: the current list (initially the default passed to the
//     constructor).
//   - allowed: when non-empty, every element Set accepts must be one of these
//     (compared after normalization); when empty, any element is accepted.
//   - normalize: optional transform applied to each element before
//     validation; nil means elements are used as-is.
//   - changed: while false, Set replaces the current list (e.g. the default);
//     once true, Set appends to it. Set and Append set it to true.
type customStringSliceValue struct {
	value     *[]string
	allowed   []string
	normalize ValueNormalizeFunc
	changed   bool
}
    51  
    52  func newCustomStringSliceValue(val, allowed []string, fn ValueNormalizeFunc) *customStringSliceValue {
    53  	return &customStringSliceValue{
    54  		value:     &val,
    55  		allowed:   allowed,
    56  		normalize: fn,
    57  	}
    58  }
    59  
    60  func (s *customStringSliceValue) Set(val string) error {
    61  	values := strings.Split(val, ",")
    62  	if s.normalize != nil {
    63  		values = lo.Map(values, func(item string, _ int) string { return s.normalize(item) })
    64  	}
    65  	for _, v := range values {
    66  		if len(s.allowed) > 0 && !slices.Contains(s.allowed, v) {
    67  			return xerrors.Errorf("must be one of %q", s.allowed)
    68  		}
    69  	}
    70  	if !s.changed {
    71  		*s.value = values
    72  	} else {
    73  		*s.value = append(*s.value, values...)
    74  	}
    75  	s.changed = true
    76  	return nil
    77  }
    78  
    79  func (s *customStringSliceValue) Type() string {
    80  	return "stringSlice"
    81  }
    82  
    83  func (s *customStringSliceValue) String() string {
    84  	if len(*s.value) == 0 {
    85  		// "[]" is not recognized as a zero value
    86  		// cf. https://github.com/spf13/pflag/blob/d5e0c0615acee7028e1e2740a11102313be88de1/flag.go#L553-L565
    87  		return ""
    88  	}
    89  	return "[" + strings.Join(*s.value, ",") + "]"
    90  }
    91  
    92  func (s *customStringSliceValue) Append(val string) error {
    93  	s.changed = true
    94  	return s.Set(val)
    95  }
    96  
    97  func (s *customStringSliceValue) Replace(val []string) error {
    98  	*s.value = val
    99  	return nil
   100  }
   101  
   102  func (s *customStringSliceValue) GetSlice() []string {
   103  	return *s.value
   104  }