github.com/mitranim/gg@v0.1.17/flag.go (about)

     1  package gg
     2  
     3  import (
     4  	"encoding"
     5  	r "reflect"
     6  )
     7  
     8  /*
     9  Parses CLI flags into an instance of the given type, which must be a struct.
    10  For parsing rules, see `FlagParser`.
    11  */
    12  func FlagParseTo[A any](src []string) (out A) {
    13  	FlagParse(src, &out)
    14  	return
    15  }
    16  
    17  /*
    18  Parses CLI flags into the given value, which must be a struct.
    19  Panics on error. For parsing rules, see `FlagParser`.
    20  */
    21  func FlagParse[A any](src []string, out *A) {
    22  	if out != nil {
    23  		FlagParseReflect(src, r.ValueOf(AnyNoEscUnsafe(out)).Elem())
    24  	}
    25  }
    26  
    27  /*
    28  Parses CLI flags into the given value, which must be a struct.
    29  For parsing rules, see `FlagParser`.
    30  */
    31  func FlagParseCatch[A any](src []string, out *A) (err error) {
    32  	defer Rec(&err)
    33  	FlagParse(src, out)
    34  	return
    35  }
    36  
    37  /*
    38  Parses CLI flags into the given output, which must be a settable struct value.
    39  For parsing rules, see `FlagParser`.
    40  */
    41  func FlagParseReflect(src []string, out r.Value) {
    42  	if !out.IsValid() {
    43  		return
    44  	}
    45  
    46  	var par FlagParser
    47  	par.Init(out)
    48  	par.Args(src)
    49  	par.Default()
    50  }
    51  
/*
Tool for parsing lists of CLI flags into structs. Partial replacement for the
standard library package "flag". Example:

	type Opt struct {
		Args []string `flag:""`
		Help bool     `flag:"-h"          desc:"Print help and exit."`
		Verb bool     `flag:"-v"          desc:"Verbose logging."`
		Src  string   `flag:"-s" init:"." desc:"Source path."`
		Out  string   `flag:"-o"          desc:"Destination path."`
	}

	func (self *Opt) Init() {
		gg.FlagParse(os.Args[1:], self)

		if self.Help {
			log.Println(gg.FlagHelp[Opt]())
			os.Exit(0)
		}

		if gg.IsZero(self.Out) {
			log.Println(`missing output path: "-o"`)
			os.Exit(1)
		}
	}

Supported struct tags:

	* `flag`: must be "" or a valid flag like "-v" or "--verbose".
	  Fields without the `flag` tag are ignored. Flags must be unique.
	* Field with `flag:""` is used for remaining non-flag args.
	  It must have a type convertible to `[]string`.
	* `init`: initial value. Used if the flag was not provided.
	* `desc`: description. Used for help printing.

Parsing rules:

	* Supports all primitive types.
	* Supports slices of arbitrary types.
	* Supports `gg.Parser`.
	* Supports `encoding.TextUnmarshaler`.
	* Supports `flag.Value` (detected via a `Set(string) error` method).
	* When a type implements several of these, `Set` takes precedence over
	  `Parser`, which takes precedence over `TextUnmarshaler`.
	* Each flag may be listed multiple times.
		* If the target is a parser, invoke its parsing method.
		* If the target is a scalar, replace the old value with the new value.
		* If the target is a slice, append the new value.
	* Flags may be written as "-flag=value" or "-flag value". Bool flags (and
	  slices of bools) support only "-flag" (meaning true) and "-flag=value";
	  they never consume the following arg.
	* The first arg that is neither a flag nor a flag's value ends flag
	  parsing. That arg and all args after it go to the `flag:""` field.
	  Without such a field, they cause a panic.
*/
type FlagParser struct {
	Tar r.Value     // Destination: a settable struct value, see `FlagParser.Init`.
	Def FlagDef     // Flag definition for the type of `Tar`, from `FlagDefCache`.
	Got Set[string] // Flags seen during parsing; `FlagParser.Default` skips these.
}
   104  
   105  /*
   106  Initializes the parser for the given destination, which must be a settable
   107  struct value.
   108  */
   109  func (self *FlagParser) Init(tar r.Value) {
   110  	self.Tar = tar
   111  	self.Def = FlagDefCache.Get(tar.Type())
   112  	self.Got = make(Set[string], len(self.Def.Flags))
   113  }
   114  
   115  /*
   116  Parses the given CLI args into the destination. May be called multiple times.
   117  Must be called after `(*FlagParser).Init`, and before `FlagParser.Default`.
   118  */
   119  func (self FlagParser) Args(src []string) {
   120  	for IsNotEmpty(src) {
   121  		if !isCliFlag(Head(src)) {
   122  			self.SetArgs(src)
   123  			return
   124  		}
   125  
   126  		head := PopHead(&src)
   127  		key, val, split := cliFlagSplit(head)
   128  		if split {
   129  			self.Got.Add(key)
   130  			self.Flag(key, val)
   131  			continue
   132  		}
   133  
   134  		self.Got.Add(head)
   135  
   136  		if IsEmpty(src) || isCliFlag(Head(src)) {
   137  			self.TrailingFlag(head)
   138  			continue
   139  		}
   140  		if self.TrailingBool(head) {
   141  			continue
   142  		}
   143  
   144  		self.Flag(head, PopHead(&src))
   145  	}
   146  }
   147  
   148  // For internal use.
   149  func (self FlagParser) SetArgs(src []string) {
   150  	field := self.Def.Args
   151  
   152  	if field.IsValid() {
   153  		self.Tar.
   154  			FieldByIndex(field.Index).
   155  			Set(r.ValueOf(src).Convert(field.Type))
   156  		return
   157  	}
   158  
   159  	if IsEmpty(src) {
   160  		return
   161  	}
   162  
   163  	panic(Errf(`unexpected non-flag args: %q`, src))
   164  }
   165  
   166  // For internal use.
   167  func (self FlagParser) FlagField(key string) r.Value {
   168  	return self.Tar.FieldByIndex(self.Def.Get(key).Index)
   169  }
   170  
   171  // For internal use.
   172  func (self FlagParser) Flag(key, src string) {
   173  	self.FieldParse(src, self.FlagField(key))
   174  }
   175  
   176  // For internal use.
   177  func (self FlagParser) FieldParse(src string, out r.Value) {
   178  	var nested bool
   179  
   180  interfaces:
   181  	ptr := out.Addr().Interface()
   182  
   183  	// Part of the `flag.Value` interface.
   184  	setter, _ := ptr.(interface{ Set(string) error })
   185  	if setter != nil {
   186  		Try(setter.Set(src))
   187  		return
   188  	}
   189  
   190  	parser, _ := ptr.(Parser)
   191  	if parser != nil {
   192  		Try(parser.Parse(src))
   193  		return
   194  	}
   195  
   196  	unmarshaler, _ := ptr.(encoding.TextUnmarshaler)
   197  	if unmarshaler != nil {
   198  		Try(unmarshaler.UnmarshalText(ToBytes(src)))
   199  		return
   200  	}
   201  
   202  	if out.Kind() == r.Slice {
   203  		growLenReflect(out)
   204  		out = out.Index(out.Len() - 1)
   205  
   206  		if !nested {
   207  			nested = true
   208  			goto interfaces
   209  		}
   210  	}
   211  
   212  	if out.Kind() == r.Bool && src == `` {
   213  		src = `true`
   214  	}
   215  	Try(ParseReflectCatch(src, out.Addr()))
   216  }
   217  
   218  // For internal use.
   219  func (self FlagParser) TrailingFlag(key string) {
   220  	// TODO: consider supporting various parser interfaces here.
   221  	if self.TrailingBool(key) {
   222  		return
   223  	}
   224  	panic(Errf(`missing value for trailing flag %q`, key))
   225  }
   226  
   227  // For internal use.
   228  func (self FlagParser) TrailingBool(key string) bool {
   229  	/**
   230  	Following the established conventions, bool flags don't support
   231  	"-flag value", only "-flag=value". A boolean flag always terminates
   232  	immediately, without looking for a following space-separated value.
   233  	*/
   234  
   235  	tar := self.FlagField(key)
   236  
   237  	if tar.Kind() == r.Bool {
   238  		tar.SetBool(true)
   239  		return true
   240  	}
   241  
   242  	if tar.Kind() == r.Slice && tar.Type().Elem().Kind() == r.Bool {
   243  		growLenReflect(tar)
   244  		tar.Index(tar.Len() - 1).SetBool(true)
   245  		return true
   246  	}
   247  
   248  	return false
   249  }
   250  
   251  /*
   252  Applies defaults to all flags which have not been found during parsing.
   253  Explicitly providing an empty value suppresses a default, although
   254  an empty string may not be a viable input to some types.
   255  */
   256  func (self FlagParser) Default() {
   257  	for _, field := range self.Def.Flags {
   258  		if !self.Got.Has(field.Flag) {
   259  			if field.InitHas {
   260  				self.FieldParse(field.Init, self.Tar.FieldByIndex(field.Index))
   261  			}
   262  		}
   263  	}
   264  }
   265  
   266  // Returns a help string for the given struct type, using `FlagFmtDefault`.
   267  func FlagHelp[A any]() string {
   268  	return FlagDefCache.Get(Type[A]()).Help()
   269  }
   270  
// Stores cached `FlagDef` definitions for struct types, looked up by
// `reflect.Type`. Shared by `FlagParser.Init` and `FlagHelp`.
var FlagDefCache = TypeCacheOf[FlagDef]()
   273  
/*
Struct type definition suitable for flag parsing. Used internally by
`FlagParser`. User code shouldn't have to use this type, but it's exported for
customization purposes. Built by `FlagDef.Init`, one field at a time via
`FlagDef.AddField`.
*/
type FlagDef struct {
	Type  r.Type         // Struct type this definition describes.
	Args  FlagDefField   // Field tagged `flag:""`, if any; receives non-flag args.
	Flags []FlagDefField // Fields with non-empty flags, in the order they were added.
	Index map[string]int // Maps each flag string to its position in `Flags`.
}
   285  
   286  // For internal use.
   287  func (self *FlagDef) Init(src r.Type) {
   288  	self.Type = src
   289  	Each(StructDeepPublicFieldCache.Get(src), self.AddField)
   290  }
   291  
   292  // For internal use.
   293  func (self *FlagDef) AddField(src r.StructField) {
   294  	var field FlagDefField
   295  	field.Set(src)
   296  	if !field.FlagHas {
   297  		return
   298  	}
   299  
   300  	if MapHas(self.Index, field.Flag) ||
   301  		(field.Flag == `` && self.Args.IsValid()) {
   302  		panic(Errf(`redundant flag %q in type %v`, field.Flag, self.Type))
   303  	}
   304  
   305  	if field.Flag == `` {
   306  		if !field.Type.ConvertibleTo(Type[[]string]()) {
   307  			panic(Errf(
   308  				`invalid type %v in field %q of type %v: args field must be convertible to []string`,
   309  				field.Type, field.Name, self.Type,
   310  			))
   311  		}
   312  		self.Args = field
   313  		return
   314  	}
   315  
   316  	if !isCliFlagValid(field.Flag) {
   317  		panic(Errf(
   318  			`invalid flag %q in field %q of type %v`,
   319  			field.Flag, field.Name, self.Type,
   320  		))
   321  	}
   322  
   323  	MapInit(&self.Index)[field.Flag] = len(self.Flags)
   324  	Append(&self.Flags, field)
   325  }
   326  
   327  // For internal use.
   328  func (self FlagDef) Got(key string) (FlagDefField, bool) {
   329  	ind, ok := self.Index[key]
   330  	if !ok {
   331  		return Zero[FlagDefField](), false
   332  	}
   333  	return Got(self.Flags, ind)
   334  }
   335  
   336  // For internal use.
   337  func (self FlagDef) Get(key string) FlagDefField {
   338  	val, ok := self.Got(key)
   339  	if !ok {
   340  		panic(Errf(`unable to find flag %q in type %v`, key, self.Type))
   341  	}
   342  	return val
   343  }
   344  
   345  // Creates a help string listing the available flags, using `FlagFmtDefault`.
   346  func (self FlagDef) Help() string { return FlagFmtDefault.String(self) }
   347  
// Used internally by `FlagDef`. Describes one struct field together with its
// `flag`, `init`, and `desc` tags. The lengths are computed with `CharCount`
// and used for help formatting.
type FlagDefField struct {
	r.StructField
	Flag    string // Value of the `flag` tag.
	FlagHas bool   // Whether the `flag` tag is present.
	FlagLen int    // `CharCount` of `Flag`.
	Init    string // Value of the `init` tag.
	InitHas bool   // Whether the `init` tag is present.
	InitLen int    // `CharCount` of `Init`.
	Desc    string // Value of the `desc` tag.
	DescHas bool   // Whether `Desc` is non-empty.
	DescLen int    // `CharCount` of `Desc`.
}
   361  
   362  func (self FlagDefField) IsValid() bool { return IsNotZero(self) }
   363  
   364  func (self *FlagDefField) Set(src r.StructField) {
   365  	self.StructField = src
   366  
   367  	self.Flag, self.FlagHas = self.Tag.Lookup(`flag`)
   368  	self.Init, self.InitHas = self.Tag.Lookup(`init`)
   369  	self.Desc = self.Tag.Get(`desc`)
   370  
   371  	self.FlagLen = CharCount(self.Flag)
   372  	self.InitLen = CharCount(self.Init)
   373  	self.DescLen = CharCount(self.Desc)
   374  
   375  	self.DescHas = self.DescLen > 0
   376  }
   377  
   378  func (self FlagDefField) GetFlagHas() bool { return self.FlagHas }
   379  func (self FlagDefField) GetInitHas() bool { return self.InitHas }
   380  func (self FlagDefField) GetDescHas() bool { return self.DescHas }
   381  
   382  func (self FlagDefField) GetFlagLen() int { return self.FlagLen }
   383  func (self FlagDefField) GetInitLen() int { return self.InitLen }
   384  func (self FlagDefField) GetDescLen() int { return self.DescLen }
   385  
// Default help formatter, used by `FlagHelp` and `FlagDef.Help`. Built by
// `With` from `(*FlagFmt).Default`. Mutate it to customize help output for
// every caller of those functions.
var FlagFmtDefault = With((*FlagFmt).Default)
   388  
/*
Table-like formatter for listing available flags, initial values, and
descriptions. Used via `FlagFmtDefault`, `FlagHelp`, `FlagDef.Help`.
To customize printing, mutate `FlagFmtDefault`.
*/
type FlagFmt struct {
	Prefix    string // Prepended before each line, including header lines.
	Infix     string // Inserted between columns.
	Head      bool   // If true, print table header.
	FlagHead  string // Title for header cell "flag".
	InitHead  string // Title for header cell "init".
	DescHead  string // Title for header cell "desc".
	HeadUnder string // Repeated between header and body; empty disables it.
}
   403  
   404  // Sets default values.
   405  func (self *FlagFmt) Default() {
   406  	self.Infix = `    `
   407  	self.Head = true
   408  	self.FlagHead = `flag`
   409  	self.InitHead = `init`
   410  	self.DescHead = `desc`
   411  	self.HeadUnder = `-`
   412  }
   413  
   414  // Returns a table-like help string for the given definition.
   415  func (self FlagFmt) String(def FlagDef) string {
   416  	return ToString(self.AppendTo(nil, def))
   417  }
   418  
   419  /*
   420  Appends table-like help for the given definition. Known limitation: assumes
   421  monospace, doesn't support wider characters such as kanji or emoji.
   422  */
   423  func (self FlagFmt) AppendTo(src []byte, def FlagDef) []byte {
   424  	flags := def.Flags
   425  	if IsEmpty(flags) {
   426  		return src
   427  	}
   428  
   429  	prefixLen := CharCount(self.Prefix)
   430  	sepLen := CharCount(self.Infix)
   431  	newlineLen := CharCount(Newline)
   432  	flagLen := MaxPrimBy(flags, FlagDefField.GetFlagLen)
   433  
   434  	var flagHeadLen int
   435  	if self.Head {
   436  		flagHeadLen = CharCount(self.FlagHead)
   437  		flagLen = MaxPrim2(flagHeadLen, flagLen)
   438  	}
   439  
   440  	var initHeadLen int
   441  	var initLen int
   442  	if Some(flags, FlagDefField.GetInitHas) {
   443  		if self.Head {
   444  			initHeadLen = CharCount(self.InitHead)
   445  		}
   446  		initLen = MaxPrim2(
   447  			initHeadLen,
   448  			MaxPrimBy(flags, FlagDefField.GetInitLen),
   449  		)
   450  	}
   451  
   452  	var initLenOuter int
   453  	if initLen > 0 {
   454  		initLenOuter = sepLen + initLen
   455  	}
   456  
   457  	var descHeadLen int
   458  	var descLen int
   459  	if Some(flags, FlagDefField.GetDescHas) {
   460  		if self.Head {
   461  			descHeadLen = CharCount(self.DescHead)
   462  		}
   463  		descLen = MaxPrim2(
   464  			descHeadLen,
   465  			MaxPrimBy(flags, FlagDefField.GetDescLen),
   466  		)
   467  	}
   468  
   469  	var descLenOuter int
   470  	if descLen > 0 {
   471  		descLenOuter = sepLen + descLen
   472  	}
   473  
   474  	rowLenInner := prefixLen + flagLen + initLenOuter + descLenOuter
   475  	rowLen := rowLenInner + newlineLen
   476  	headUnderLen := CharCount(self.HeadUnder)
   477  
   478  	buf := Buf(src)
   479  	buf.GrowCap(((2 + len(flags)) * rowLen))
   480  
   481  	if self.Head {
   482  		buf.AppendString(self.Prefix)
   483  
   484  		buf.AppendString(self.FlagHead)
   485  		buf.AppendSpaces(flagLen - flagHeadLen)
   486  		if initLen > 0 || descLen > 0 {
   487  			buf.AppendSpaces(flagLen - flagHeadLen)
   488  		}
   489  
   490  		if initLen > 0 {
   491  			buf.AppendString(self.Infix)
   492  			buf.AppendString(self.InitHead)
   493  			if descLen > 0 {
   494  				buf.AppendSpaces(initLen - initHeadLen)
   495  			}
   496  		}
   497  
   498  		if descLen > 0 {
   499  			buf.AppendString(self.Infix)
   500  			buf.AppendString(self.DescHead)
   501  		}
   502  
   503  		buf.AppendString(Newline)
   504  
   505  		if rowLenInner > 0 && headUnderLen > 0 {
   506  			buf.AppendString(self.Prefix)
   507  			buf.AppendStringN(self.HeadUnder, rowLenInner/headUnderLen)
   508  			buf.AppendString(Newline)
   509  		}
   510  	}
   511  
   512  	for _, field := range flags {
   513  		buf.AppendString(self.Prefix)
   514  
   515  		buf.AppendString(field.Flag)
   516  		if initLen > 0 || descLen > 0 {
   517  			buf.AppendSpaces(flagLen - field.FlagLen)
   518  		}
   519  
   520  		if initLen > 0 {
   521  			buf.AppendString(self.Infix)
   522  			buf.AppendString(field.Init)
   523  			if descLen > 0 {
   524  				buf.AppendSpaces(initLen - field.InitLen)
   525  			}
   526  		}
   527  
   528  		if descLen > 0 {
   529  			buf.AppendString(self.Infix)
   530  			buf.AppendString(field.Desc)
   531  		}
   532  		buf.AppendString(Newline)
   533  	}
   534  
   535  	return buf
   536  }