github.com/fiatjaf/generic-ristretto@v0.0.1/z/flags.go (about)

     1  package z
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"os/user"
     7  	"path/filepath"
     8  	"sort"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/golang/glog"
    14  	"github.com/pkg/errors"
    15  )
    16  
// SuperFlagHelp makes it really easy to generate command line `--help` output for a SuperFlag. For
// example:
//
//	const flagDefaults = `enabled=true; path=some/path;`
//
//	var help string = z.NewSuperFlagHelp(flagDefaults).
//		Flag("enabled", "Turns on <something>.").
//		Flag("path", "The path to <something>.").
//		Flag("another", "Not present in defaults, but still included.").
//		String()
//
// The `help` string would then contain a first line holding the text set via Head (empty in this
// example), followed by one line per flag, each indented by four spaces:
//
//	    enabled=true; Turns on <something>.
//	    path=some/path; The path to <something>.
//	    another=; Not present in defaults, but still included.
//
// Flags are sorted within each group for consistent `--help` output. Flags that have an entry in
// the defaults are placed at the top, and everything else goes under. Only flags registered with
// Flag are listed; a default with no registered description is not printed.
type SuperFlagHelp struct {
	// head is printed on its own line before the flag list (set via Head).
	head     string
	// defaults holds the parsed default values, looked up by flag name when rendering.
	defaults *SuperFlag
	// flags maps a flag name to its help description (set via Flag).
	flags    map[string]string
}
    41  
    42  func NewSuperFlagHelp(defaults string) *SuperFlagHelp {
    43  	return &SuperFlagHelp{
    44  		defaults: NewSuperFlag(defaults),
    45  		flags:    make(map[string]string, 0),
    46  	}
    47  }
    48  
    49  func (h *SuperFlagHelp) Head(head string) *SuperFlagHelp {
    50  	h.head = head
    51  	return h
    52  }
    53  
    54  func (h *SuperFlagHelp) Flag(name, description string) *SuperFlagHelp {
    55  	h.flags[name] = description
    56  	return h
    57  }
    58  
    59  func (h *SuperFlagHelp) String() string {
    60  	defaultLines := make([]string, 0)
    61  	otherLines := make([]string, 0)
    62  	for name, help := range h.flags {
    63  		val, found := h.defaults.m[name]
    64  		line := fmt.Sprintf("    %s=%s; %s\n", name, val, help)
    65  		if found {
    66  			defaultLines = append(defaultLines, line)
    67  		} else {
    68  			otherLines = append(otherLines, line)
    69  		}
    70  	}
    71  	sort.Strings(defaultLines)
    72  	sort.Strings(otherLines)
    73  	dls := strings.Join(defaultLines, "")
    74  	ols := strings.Join(otherLines, "")
    75  	if len(h.defaults.m) == 0 && len(ols) == 0 {
    76  		// remove last newline
    77  		dls = dls[:len(dls)-1]
    78  	}
    79  	// remove last newline
    80  	if len(h.defaults.m) == 0 && len(ols) > 1 {
    81  		ols = ols[:len(ols)-1]
    82  	}
    83  	return h.head + "\n" + dls + ols
    84  }
    85  
    86  func parseFlag(flag string) (map[string]string, error) {
    87  	kvm := make(map[string]string)
    88  	for _, kv := range strings.Split(flag, ";") {
    89  		if strings.TrimSpace(kv) == "" {
    90  			continue
    91  		}
    92  		// For a non-empty separator, 0 < len(splits) ≤ 2.
    93  		splits := strings.SplitN(kv, "=", 2)
    94  		k := strings.TrimSpace(splits[0])
    95  		if len(splits) < 2 {
    96  			return nil, fmt.Errorf("superflag: missing value for '%s' in flag: %s", k, flag)
    97  		}
    98  		k = strings.ToLower(k)
    99  		k = strings.ReplaceAll(k, "_", "-")
   100  		kvm[k] = strings.TrimSpace(splits[1])
   101  	}
   102  	return kvm, nil
   103  }
   104  
// SuperFlag holds a set of "key=value" options parsed from a single flag string. The
// constructors in this file fill it from parseFlag.
type SuperFlag struct {
	// m maps option names (lowercased, with '_' replaced by '-') to their trimmed values.
	m map[string]string
}
   108  
   109  func NewSuperFlag(flag string) *SuperFlag {
   110  	sf, err := newSuperFlagImpl(flag)
   111  	if err != nil {
   112  		glog.Fatal(err)
   113  	}
   114  	return sf
   115  }
   116  
   117  func newSuperFlagImpl(flag string) (*SuperFlag, error) {
   118  	m, err := parseFlag(flag)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	return &SuperFlag{m}, nil
   123  }
   124  
   125  func (sf *SuperFlag) String() string {
   126  	if sf == nil {
   127  		return ""
   128  	}
   129  	kvs := make([]string, 0, len(sf.m))
   130  	for k, v := range sf.m {
   131  		kvs = append(kvs, fmt.Sprintf("%s=%s", k, v))
   132  	}
   133  	return strings.Join(kvs, "; ")
   134  }
   135  
   136  func (sf *SuperFlag) MergeAndCheckDefault(flag string) *SuperFlag {
   137  	sf, err := sf.mergeAndCheckDefaultImpl(flag)
   138  	if err != nil {
   139  		glog.Fatal(err)
   140  	}
   141  	return sf
   142  }
   143  
   144  func (sf *SuperFlag) mergeAndCheckDefaultImpl(flag string) (*SuperFlag, error) {
   145  	if sf == nil {
   146  		m, err := parseFlag(flag)
   147  		if err != nil {
   148  			return nil, err
   149  		}
   150  		return &SuperFlag{m}, nil
   151  	}
   152  
   153  	src, err := parseFlag(flag)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	numKeys := len(sf.m)
   159  	for k := range src {
   160  		if _, ok := sf.m[k]; ok {
   161  			numKeys--
   162  		}
   163  	}
   164  	if numKeys != 0 {
   165  		return nil, fmt.Errorf("superflag: found invalid options in flag: %s.\nvalid options: %v", sf, flag)
   166  	}
   167  	for k, v := range src {
   168  		if _, ok := sf.m[k]; !ok {
   169  			sf.m[k] = v
   170  		}
   171  	}
   172  	return sf, nil
   173  }
   174  
   175  func (sf *SuperFlag) Has(opt string) bool {
   176  	val := sf.GetString(opt)
   177  	return val != ""
   178  }
   179  
   180  func (sf *SuperFlag) GetDuration(opt string) time.Duration {
   181  	val := sf.GetString(opt)
   182  	if val == "" {
   183  		return time.Duration(0)
   184  	}
   185  	if strings.Contains(val, "d") {
   186  		val = strings.Replace(val, "d", "", 1)
   187  		days, err := strconv.ParseUint(val, 0, 64)
   188  		if err != nil {
   189  			return time.Duration(0)
   190  		}
   191  		return time.Hour * 24 * time.Duration(days)
   192  	}
   193  	d, err := time.ParseDuration(val)
   194  	if err != nil {
   195  		return time.Duration(0)
   196  	}
   197  	return d
   198  }
   199  
   200  func (sf *SuperFlag) GetBool(opt string) bool {
   201  	val := sf.GetString(opt)
   202  	if val == "" {
   203  		return false
   204  	}
   205  	b, err := strconv.ParseBool(val)
   206  	if err != nil {
   207  		err = errors.Wrapf(err,
   208  			"Unable to parse %s as bool for key: %s. Options: %s\n",
   209  			val, opt, sf)
   210  		glog.Fatalf("%+v", err)
   211  	}
   212  	return b
   213  }
   214  
   215  func (sf *SuperFlag) GetFloat64(opt string) float64 {
   216  	val := sf.GetString(opt)
   217  	if val == "" {
   218  		return 0
   219  	}
   220  	f, err := strconv.ParseFloat(val, 64)
   221  	if err != nil {
   222  		err = errors.Wrapf(err,
   223  			"Unable to parse %s as float64 for key: %s. Options: %s\n",
   224  			val, opt, sf)
   225  		glog.Fatalf("%+v", err)
   226  	}
   227  	return f
   228  }
   229  
   230  func (sf *SuperFlag) GetInt64(opt string) int64 {
   231  	val := sf.GetString(opt)
   232  	if val == "" {
   233  		return 0
   234  	}
   235  	i, err := strconv.ParseInt(val, 0, 64)
   236  	if err != nil {
   237  		err = errors.Wrapf(err,
   238  			"Unable to parse %s as int64 for key: %s. Options: %s\n",
   239  			val, opt, sf)
   240  		glog.Fatalf("%+v", err)
   241  	}
   242  	return i
   243  }
   244  
   245  func (sf *SuperFlag) GetUint64(opt string) uint64 {
   246  	val := sf.GetString(opt)
   247  	if val == "" {
   248  		return 0
   249  	}
   250  	u, err := strconv.ParseUint(val, 0, 64)
   251  	if err != nil {
   252  		err = errors.Wrapf(err,
   253  			"Unable to parse %s as uint64 for key: %s. Options: %s\n",
   254  			val, opt, sf)
   255  		glog.Fatalf("%+v", err)
   256  	}
   257  	return u
   258  }
   259  
   260  func (sf *SuperFlag) GetUint32(opt string) uint32 {
   261  	val := sf.GetString(opt)
   262  	if val == "" {
   263  		return 0
   264  	}
   265  	u, err := strconv.ParseUint(val, 0, 32)
   266  	if err != nil {
   267  		err = errors.Wrapf(err,
   268  			"Unable to parse %s as uint32 for key: %s. Options: %s\n",
   269  			val, opt, sf)
   270  		glog.Fatalf("%+v", err)
   271  	}
   272  	return uint32(u)
   273  }
   274  
   275  func (sf *SuperFlag) GetString(opt string) string {
   276  	if sf == nil {
   277  		return ""
   278  	}
   279  	return sf.m[opt]
   280  }
   281  
   282  func (sf *SuperFlag) GetPath(opt string) string {
   283  	p := sf.GetString(opt)
   284  	path, err := expandPath(p)
   285  	if err != nil {
   286  		glog.Fatalf("Failed to get path: %+v", err)
   287  	}
   288  	return path
   289  }
   290  
   291  // expandPath expands the paths containing ~ to /home/user. It also computes the absolute path
   292  // from the relative paths. For example: ~/abc/../cef will be transformed to /home/user/cef.
   293  func expandPath(path string) (string, error) {
   294  	if len(path) == 0 {
   295  		return "", nil
   296  	}
   297  	if path[0] == '~' && (len(path) == 1 || os.IsPathSeparator(path[1])) {
   298  		usr, err := user.Current()
   299  		if err != nil {
   300  			return "", errors.Wrap(err, "Failed to get the home directory of the user")
   301  		}
   302  		path = filepath.Join(usr.HomeDir, path[1:])
   303  	}
   304  
   305  	var err error
   306  	path, err = filepath.Abs(path)
   307  	if err != nil {
   308  		return "", errors.Wrap(err, "Failed to generate absolute path")
   309  	}
   310  	return path, nil
   311  }