github.com/etecs-ru/ristretto@v0.9.1/z/flags.go (about)

     1  package z
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"os"
     7  	"os/user"
     8  	"path/filepath"
     9  	"sort"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  )
    14  
    15  // SuperFlagHelp makes it really easy to generate command line `--help` output for a SuperFlag. For
    16  // example:
    17  //
    18  //	const flagDefaults = `enabled=true; path=some/path;`
    19  //
    20  //	var help string = z.NewSuperFlagHelp(flagDefaults).
    21  //		Flag("enabled", "Turns on <something>.").
    22  //		Flag("path", "The path to <something>.").
    23  //		Flag("another", "Not present in defaults, but still included.").
    24  //		String()
    25  //
    26  // The `help` string would then contain:
    27  //
    28  //	enabled=true; Turns on <something>.
    29  //	path=some/path; The path to <something>.
    30  //	another=; Not present in defaults, but still included.
    31  //
    32  // All flags are sorted alphabetically for consistent `--help` output. Flags with default values are
    33  // placed at the top, and everything else goes under.
    34  type SuperFlagHelp struct {
    35  	defaults *SuperFlag
    36  	flags    map[string]string
    37  	head     string
    38  }
    39  
    40  func NewSuperFlagHelp(defaults string) *SuperFlagHelp {
    41  	return &SuperFlagHelp{
    42  		defaults: NewSuperFlag(defaults),
    43  		flags:    make(map[string]string, 0),
    44  	}
    45  }
    46  
    47  func (h *SuperFlagHelp) Head(head string) *SuperFlagHelp {
    48  	h.head = head
    49  	return h
    50  }
    51  
    52  func (h *SuperFlagHelp) Flag(name, description string) *SuperFlagHelp {
    53  	h.flags[name] = description
    54  	return h
    55  }
    56  
    57  func (h *SuperFlagHelp) String() string {
    58  	defaultLines := make([]string, 0)
    59  	otherLines := make([]string, 0)
    60  	for name, help := range h.flags {
    61  		val, found := h.defaults.m[name]
    62  		line := fmt.Sprintf("    %s=%s; %s\n", name, val, help)
    63  		if found {
    64  			defaultLines = append(defaultLines, line)
    65  		} else {
    66  			otherLines = append(otherLines, line)
    67  		}
    68  	}
    69  	sort.Strings(defaultLines)
    70  	sort.Strings(otherLines)
    71  	dls := strings.Join(defaultLines, "")
    72  	ols := strings.Join(otherLines, "")
    73  	if len(h.defaults.m) == 0 && len(ols) == 0 {
    74  		// remove last newline
    75  		dls = dls[:len(dls)-1]
    76  	}
    77  	// remove last newline
    78  	if len(h.defaults.m) == 0 && len(ols) > 1 {
    79  		ols = ols[:len(ols)-1]
    80  	}
    81  	return h.head + "\n" + dls + ols
    82  }
    83  
    84  func parseFlag(flag string) (map[string]string, error) {
    85  	kvm := make(map[string]string)
    86  	for _, kv := range strings.Split(flag, ";") {
    87  		if strings.TrimSpace(kv) == "" {
    88  			continue
    89  		}
    90  		// For a non-empty separator, 0 < len(splits) ≤ 2.
    91  		splits := strings.SplitN(kv, "=", 2)
    92  		k := strings.TrimSpace(splits[0])
    93  		if len(splits) < 2 {
    94  			return nil, fmt.Errorf("superflag: missing value for '%s' in flag: %s", k, flag)
    95  		}
    96  		k = strings.ToLower(k)
    97  		k = strings.ReplaceAll(k, "_", "-")
    98  		kvm[k] = strings.TrimSpace(splits[1])
    99  	}
   100  	return kvm, nil
   101  }
   102  
   103  type SuperFlag struct {
   104  	m map[string]string
   105  }
   106  
   107  func NewSuperFlag(flag string) *SuperFlag {
   108  	sf, err := newSuperFlagImpl(flag)
   109  	if err != nil {
   110  		log.Fatal(err)
   111  	}
   112  	return sf
   113  }
   114  
   115  func newSuperFlagImpl(flag string) (*SuperFlag, error) {
   116  	m, err := parseFlag(flag)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	return &SuperFlag{m}, nil
   121  }
   122  
   123  func (sf *SuperFlag) String() string {
   124  	if sf == nil {
   125  		return ""
   126  	}
   127  	kvs := make([]string, 0, len(sf.m))
   128  	for k, v := range sf.m {
   129  		kvs = append(kvs, fmt.Sprintf("%s=%s", k, v))
   130  	}
   131  	return strings.Join(kvs, "; ")
   132  }
   133  
   134  func (sf *SuperFlag) MergeAndCheckDefault(flag string) *SuperFlag {
   135  	sf, err := sf.MergeWithDefault(flag)
   136  	if err != nil {
   137  		log.Fatal(err)
   138  	}
   139  	return sf
   140  }
   141  
   142  func (sf *SuperFlag) MergeWithDefault(flag string) (*SuperFlag, error) {
   143  	if sf == nil {
   144  		m, err := parseFlag(flag)
   145  		if err != nil {
   146  			return nil, err
   147  		}
   148  		return &SuperFlag{m}, nil
   149  	}
   150  
   151  	src, err := parseFlag(flag)
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  
   156  	numKeys := len(sf.m)
   157  	for k := range src {
   158  		if _, ok := sf.m[k]; ok {
   159  			numKeys--
   160  		}
   161  	}
   162  	if numKeys != 0 {
   163  		return nil, fmt.Errorf("superflag: found invalid options in flag: %s.\nvalid options: %v", sf, flag)
   164  	}
   165  	for k, v := range src {
   166  		if _, ok := sf.m[k]; !ok {
   167  			sf.m[k] = v
   168  		}
   169  	}
   170  	return sf, nil
   171  }
   172  
   173  func (sf *SuperFlag) Has(opt string) bool {
   174  	val := sf.GetString(opt)
   175  	return val != ""
   176  }
   177  
   178  func (sf *SuperFlag) GetDuration(opt string) time.Duration {
   179  	val := sf.GetString(opt)
   180  	if val == "" {
   181  		return time.Duration(0)
   182  	}
   183  	if strings.Contains(val, "d") {
   184  		val = strings.Replace(val, "d", "", 1)
   185  		days, err := strconv.ParseUint(val, 0, 64)
   186  		if err != nil {
   187  			return time.Duration(0)
   188  		}
   189  		return time.Hour * 24 * time.Duration(days)
   190  	}
   191  	d, err := time.ParseDuration(val)
   192  	if err != nil {
   193  		return time.Duration(0)
   194  	}
   195  	return d
   196  }
   197  
   198  func (sf *SuperFlag) GetBool(opt string) bool {
   199  	val := sf.GetString(opt)
   200  	if val == "" {
   201  		return false
   202  	}
   203  	b, err := strconv.ParseBool(val)
   204  	if err != nil {
   205  		log.Fatalf("unable to parse %s as bool for key: %s. Options: %s: %v",
   206  			val, opt, sf, err)
   207  	}
   208  	return b
   209  }
   210  
   211  func (sf *SuperFlag) GetFloat64(opt string) float64 {
   212  	val := sf.GetString(opt)
   213  	if val == "" {
   214  		return 0
   215  	}
   216  	f, err := strconv.ParseFloat(val, 64)
   217  	if err != nil {
   218  		log.Fatalf("unable to parse %s as float64 for key: %s. Options: %s: %v",
   219  			val, opt, sf, err)
   220  	}
   221  	return f
   222  }
   223  
   224  func (sf *SuperFlag) GetInt64(opt string) int64 {
   225  	val := sf.GetString(opt)
   226  	if val == "" {
   227  		return 0
   228  	}
   229  	i, err := strconv.ParseInt(val, 0, 64)
   230  	if err != nil {
   231  		log.Fatalf("unable to parse %s as int64 for key: %s. Options: %s: %v",
   232  			val, opt, sf, err)
   233  	}
   234  	return i
   235  }
   236  
   237  func (sf *SuperFlag) GetUint64(opt string) uint64 {
   238  	val := sf.GetString(opt)
   239  	if val == "" {
   240  		return 0
   241  	}
   242  	u, err := strconv.ParseUint(val, 0, 64)
   243  	if err != nil {
   244  		log.Fatalf("unable to parse %s as uint64 for key: %s. Options: %s: %v",
   245  			val, opt, sf, err)
   246  	}
   247  	return u
   248  }
   249  
   250  func (sf *SuperFlag) GetUint32(opt string) uint32 {
   251  	val := sf.GetString(opt)
   252  	if val == "" {
   253  		return 0
   254  	}
   255  	u, err := strconv.ParseUint(val, 0, 32)
   256  	if err != nil {
   257  		log.Fatalf("unable to parse %s as uint32 for key: %s. Options: %s: %v",
   258  			val, opt, sf, err)
   259  	}
   260  	return uint32(u)
   261  }
   262  
   263  func (sf *SuperFlag) GetString(opt string) string {
   264  	if sf == nil {
   265  		return ""
   266  	}
   267  	return sf.m[opt]
   268  }
   269  
   270  func (sf *SuperFlag) GetPath(opt string) string {
   271  	p := sf.GetString(opt)
   272  	path, err := expandPath(p)
   273  	if err != nil {
   274  		log.Fatalf("Failed to get path: %+v", err)
   275  	}
   276  	return path
   277  }
   278  
   279  // expandPath expands the paths containing ~ to /home/user. It also computes the absolute path
   280  // from the relative paths. For example: ~/abc/../cef will be transformed to /home/user/cef.
   281  func expandPath(path string) (string, error) {
   282  	if len(path) == 0 {
   283  		return "", nil
   284  	}
   285  	if path[0] == '~' && (len(path) == 1 || os.IsPathSeparator(path[1])) {
   286  		usr, err := user.Current()
   287  		if err != nil {
   288  			return "", fmt.Errorf("failed to get the home directory of the user: %w", err)
   289  		}
   290  		path = filepath.Join(usr.HomeDir, path[1:])
   291  	}
   292  
   293  	var err error
   294  	path, err = filepath.Abs(path)
   295  	if err != nil {
   296  		return "", fmt.Errorf("failed to generate absolute path: %w", err)
   297  	}
   298  
   299  	return path, nil
   300  }