github.com/outcaste-io/ristretto@v0.2.3/z/flags.go (about)

     1  package z
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"os/user"
     7  	"path/filepath"
     8  	"sort"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/pkg/errors"
    14  )
    15  
    16  // SuperFlagHelp makes it really easy to generate command line `--help` output for a SuperFlag. For
    17  // example:
    18  //
    19  //	const flagDefaults = `enabled=true; path=some/path;`
    20  //
    21  //	var help string = z.NewSuperFlagHelp(flagDefaults).
    22  //		Flag("enabled", "Turns on <something>.").
    23  //		Flag("path", "The path to <something>.").
    24  //		Flag("another", "Not present in defaults, but still included.").
    25  //		String()
    26  //
    27  // The `help` string would then contain:
    28  //
    29  //	enabled=true; Turns on <something>.
    30  //	path=some/path; The path to <something>.
    31  //	another=; Not present in defaults, but still included.
    32  //
    33  // All flags are sorted alphabetically for consistent `--help` output. Flags with default values are
    34  // placed at the top, and everything else goes under.
    35  type SuperFlagHelp struct {
    36  	head     string
    37  	defaults *SuperFlag
    38  	flags    map[string]string
    39  }
    40  
    41  func NewSuperFlagHelp(defaults string) *SuperFlagHelp {
    42  	return &SuperFlagHelp{
    43  		defaults: NewSuperFlag(defaults),
    44  		flags:    make(map[string]string, 0),
    45  	}
    46  }
    47  
    48  func (h *SuperFlagHelp) Head(head string) *SuperFlagHelp {
    49  	h.head = head
    50  	return h
    51  }
    52  
    53  func (h *SuperFlagHelp) Flag(name, description string) *SuperFlagHelp {
    54  	h.flags[name] = description
    55  	return h
    56  }
    57  
    58  func (h *SuperFlagHelp) String() string {
    59  	defaultLines := make([]string, 0)
    60  	otherLines := make([]string, 0)
    61  	for name, help := range h.flags {
    62  		val, found := h.defaults.m[name]
    63  		line := fmt.Sprintf("    %s=%s; %s\n", name, val, help)
    64  		if found {
    65  			defaultLines = append(defaultLines, line)
    66  		} else {
    67  			otherLines = append(otherLines, line)
    68  		}
    69  	}
    70  	sort.Strings(defaultLines)
    71  	sort.Strings(otherLines)
    72  	dls := strings.Join(defaultLines, "")
    73  	ols := strings.Join(otherLines, "")
    74  	if len(h.defaults.m) == 0 && len(ols) == 0 {
    75  		// remove last newline
    76  		dls = dls[:len(dls)-1]
    77  	}
    78  	// remove last newline
    79  	if len(h.defaults.m) == 0 && len(ols) > 1 {
    80  		ols = ols[:len(ols)-1]
    81  	}
    82  	return h.head + "\n" + dls + ols
    83  }
    84  
    85  func parseFlag(flag string) (map[string]string, error) {
    86  	kvm := make(map[string]string)
    87  	for _, kv := range strings.Split(flag, ";") {
    88  		if strings.TrimSpace(kv) == "" {
    89  			continue
    90  		}
    91  		// For a non-empty separator, 0 < len(splits) ≤ 2.
    92  		splits := strings.SplitN(kv, "=", 2)
    93  		k := strings.TrimSpace(splits[0])
    94  		if len(splits) < 2 {
    95  			return nil, fmt.Errorf("superflag: missing value for '%s' in flag: %s", k, flag)
    96  		}
    97  		k = strings.ToLower(k)
    98  		k = strings.ReplaceAll(k, "_", "-")
    99  		kvm[k] = strings.TrimSpace(splits[1])
   100  	}
   101  	return kvm, nil
   102  }
   103  
   104  type SuperFlag struct {
   105  	m map[string]string
   106  }
   107  
   108  func NewSuperFlag(flag string) *SuperFlag {
   109  	sf, err := newSuperFlagImpl(flag)
   110  	if err != nil {
   111  		fatal(err)
   112  	}
   113  	return sf
   114  }
   115  
   116  func newSuperFlagImpl(flag string) (*SuperFlag, error) {
   117  	m, err := parseFlag(flag)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	return &SuperFlag{m}, nil
   122  }
   123  
   124  func (sf *SuperFlag) String() string {
   125  	if sf == nil {
   126  		return ""
   127  	}
   128  	kvs := make([]string, 0, len(sf.m))
   129  	for k, v := range sf.m {
   130  		kvs = append(kvs, fmt.Sprintf("%s=%s", k, v))
   131  	}
   132  	return strings.Join(kvs, "; ")
   133  }
   134  
   135  func (sf *SuperFlag) MergeAndCheckDefault(flag string) *SuperFlag {
   136  	sf, err := sf.MergeWithDefault(flag)
   137  	if err != nil {
   138  		fatal(err)
   139  	}
   140  	return sf
   141  }
   142  
   143  func (sf *SuperFlag) Merge(flag string) *SuperFlag {
   144  	src, err := parseFlag(flag)
   145  	if err != nil {
   146  		fatal(err)
   147  	}
   148  	for k, v := range src {
   149  		if _, ok := sf.m[k]; !ok {
   150  			fatal("Unable to find the flag in SuperFlag")
   151  		}
   152  		sf.m[k] = v
   153  	}
   154  	return sf
   155  }
   156  
   157  func (sf *SuperFlag) MergeWithDefault(flag string) (*SuperFlag, error) {
   158  	if sf == nil {
   159  		m, err := parseFlag(flag)
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  		return &SuperFlag{m}, nil
   164  	}
   165  
   166  	src, err := parseFlag(flag)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	numKeys := len(sf.m)
   172  	for k := range src {
   173  		if _, ok := sf.m[k]; ok {
   174  			numKeys--
   175  		}
   176  	}
   177  	if numKeys != 0 {
   178  		return nil, fmt.Errorf("superflag: found invalid options in flag: %s.\nvalid options: %v", sf, flag)
   179  	}
   180  	for k, v := range src {
   181  		if _, ok := sf.m[k]; !ok {
   182  			sf.m[k] = v
   183  		}
   184  	}
   185  	return sf, nil
   186  }
   187  
   188  func (sf *SuperFlag) Has(opt string) bool {
   189  	val := sf.GetString(opt)
   190  	return val != ""
   191  }
   192  
   193  func (sf *SuperFlag) GetDuration(opt string) time.Duration {
   194  	val := sf.GetString(opt)
   195  	if val == "" {
   196  		return time.Duration(0)
   197  	}
   198  	if strings.Contains(val, "d") {
   199  		val = strings.Replace(val, "d", "", 1)
   200  		days, err := strconv.ParseUint(val, 0, 64)
   201  		if err != nil {
   202  			return time.Duration(0)
   203  		}
   204  		return time.Hour * 24 * time.Duration(days)
   205  	}
   206  	d, err := time.ParseDuration(val)
   207  	if err != nil {
   208  		return time.Duration(0)
   209  	}
   210  	return d
   211  }
   212  
   213  func (sf *SuperFlag) GetBool(opt string) bool {
   214  	val := sf.GetString(opt)
   215  	if val == "" {
   216  		return false
   217  	}
   218  	b, err := strconv.ParseBool(val)
   219  	if err != nil {
   220  		err = errors.Wrapf(err,
   221  			"Unable to parse %s as bool for key: %s. Options: %s\n",
   222  			val, opt, sf)
   223  		fatalf("%+v", err)
   224  	}
   225  	return b
   226  }
   227  
   228  func (sf *SuperFlag) GetFloat64(opt string) float64 {
   229  	val := sf.GetString(opt)
   230  	if val == "" {
   231  		return 0
   232  	}
   233  	f, err := strconv.ParseFloat(val, 64)
   234  	if err != nil {
   235  		err = errors.Wrapf(err,
   236  			"Unable to parse %s as float64 for key: %s. Options: %s\n",
   237  			val, opt, sf)
   238  		fatalf("%+v", err)
   239  	}
   240  	return f
   241  }
   242  
   243  func (sf *SuperFlag) GetInt64(opt string) int64 {
   244  	val := sf.GetString(opt)
   245  	if val == "" {
   246  		return 0
   247  	}
   248  	i, err := strconv.ParseInt(val, 0, 64)
   249  	if err != nil {
   250  		err = errors.Wrapf(err,
   251  			"Unable to parse %s as int64 for key: %s. Options: %s\n",
   252  			val, opt, sf)
   253  		fatalf("%+v", err)
   254  	}
   255  	return i
   256  }
   257  
   258  func (sf *SuperFlag) GetUint64(opt string) uint64 {
   259  	val := sf.GetString(opt)
   260  	if val == "" {
   261  		return 0
   262  	}
   263  	u, err := strconv.ParseUint(val, 0, 64)
   264  	if err != nil {
   265  		err = errors.Wrapf(err,
   266  			"Unable to parse %s as uint64 for key: %s. Options: %s\n",
   267  			val, opt, sf)
   268  		fatalf("%+v", err)
   269  	}
   270  	return u
   271  }
   272  
   273  func (sf *SuperFlag) GetUint32(opt string) uint32 {
   274  	val := sf.GetString(opt)
   275  	if val == "" {
   276  		return 0
   277  	}
   278  	u, err := strconv.ParseUint(val, 0, 32)
   279  	if err != nil {
   280  		err = errors.Wrapf(err,
   281  			"Unable to parse %s as uint32 for key: %s. Options: %s\n",
   282  			val, opt, sf)
   283  		fatalf("%+v", err)
   284  	}
   285  	return uint32(u)
   286  }
   287  
   288  func (sf *SuperFlag) GetString(opt string) string {
   289  	if sf == nil {
   290  		return ""
   291  	}
   292  	return sf.m[opt]
   293  }
   294  
   295  func (sf *SuperFlag) GetPath(opt string) string {
   296  	p := sf.GetString(opt)
   297  	path, err := expandPath(p)
   298  	if err != nil {
   299  		fatalf("Failed to get path: %+v", err)
   300  	}
   301  	return path
   302  }
   303  
   304  // expandPath expands the paths containing ~ to /home/user. It also computes the absolute path
   305  // from the relative paths. For example: ~/abc/../cef will be transformed to /home/user/cef.
   306  func expandPath(path string) (string, error) {
   307  	if len(path) == 0 {
   308  		return "", nil
   309  	}
   310  	if path[0] == '~' && (len(path) == 1 || os.IsPathSeparator(path[1])) {
   311  		usr, err := user.Current()
   312  		if err != nil {
   313  			return "", errors.Wrap(err, "Failed to get the home directory of the user")
   314  		}
   315  		path = filepath.Join(usr.HomeDir, path[1:])
   316  	}
   317  
   318  	var err error
   319  	path, err = filepath.Abs(path)
   320  	if err != nil {
   321  		return "", errors.Wrap(err, "Failed to generate absolute path")
   322  	}
   323  	return path, nil
   324  }