github.com/puellanivis/breton@v0.2.16/lib/gnuflag/struct.go (about)

     1  package gnuflag
     2  
     3  import (
     4  	"fmt"
     5  	"net/url"
     6  	"reflect"
     7  	"strconv"
     8  	"strings"
     9  	"time"
    10  	"unicode"
    11  	"unicode/utf8"
    12  	"unsafe"
    13  )
    14  
    15  // flagName takes a variable name and produces from "FlagNameWithDashes" a "flag-name-with-dashes".
    16  // It also attempts to detect acronyms all in upper case, as is Go style.
    17  // It is _best effort_ and may entirely mangle your flagname,
    18  // e.g. "HTTPURL" will be interpreted as a single acronym.
    19  // It exists solely to give a better default than strings.ToLower(),
    20  // it should almost certainly be overridden with an intentional name.
    21  func flagName(name string) string {
    22  	var words []string
    23  	var word []rune
    24  	var maybeAcronym bool
    25  
    26  	for _, r := range name {
    27  		switch {
    28  		case unicode.IsUpper(r):
    29  			if !maybeAcronym && len(word) > 0 {
    30  				words = append(words, string(word))
    31  				word = word[:0] // reuse previous allocated slice.
    32  			}
    33  
    34  			maybeAcronym = true
    35  			word = append(word, unicode.ToLower(r))
    36  
    37  		case maybeAcronym && len(word) > 1: // an acronym can only be from two uppercase letters together.
    38  			l := len(word) - 1
    39  
    40  			words = append(words, string(word[:l]))
    41  
    42  			word[0] = word[l]
    43  			word = word[:1]
    44  			fallthrough
    45  
    46  		default:
    47  			maybeAcronym = false
    48  			word = append(word, r)
    49  		}
    50  	}
    51  
    52  	if len(word) > 0 {
    53  		words = append(words, string(word))
    54  	}
    55  
    56  	return strings.Join(words, "-")
    57  }
    58  
    59  // structVar is the work-horse, and does the actual reflection and recursive work.
    60  func (fs *FlagSet) structVar(prefix string, v reflect.Value) error {
    61  	if strings.Contains(prefix, "=") || strings.HasPrefix(prefix, "-") {
    62  		return fmt.Errorf("invalid prefix: %q", prefix)
    63  	}
    64  
    65  	structType := v.Type()
    66  
    67  	for i := 0; i < structType.NumField(); i++ {
    68  		val := v.Field(i)
    69  		if !val.CanSet() {
    70  			continue
    71  		}
    72  
    73  		field := structType.Field(i)
    74  		name := flagName(field.Name)
    75  
    76  		usage := field.Tag.Get("desc")
    77  		if usage == "" {
    78  			usage = fmt.Sprintf("%s `%s`", field.Name, field.Type)
    79  		}
    80  
    81  		var short rune
    82  		defval := field.Tag.Get("default")
    83  
    84  		if tag := field.Tag.Get("flag"); tag != "" {
    85  			directives := strings.Split(tag, ",")
    86  
    87  			if len(directives) >= 1 {
    88  				// sanity check: by documentation this should always be true.
    89  
    90  				if directives[0] != "" {
    91  					name = directives[0]
    92  				}
    93  			}
    94  
    95  			if name == "-" {
    96  				continue
    97  			}
    98  
    99  		directivesLoop:
   100  			for j := 1; j < len(directives); j++ {
   101  				directive := directives[j]
   102  
   103  				switch {
   104  				case strings.HasPrefix(directive, "short="):
   105  					short, _ = utf8.DecodeRuneInString(strings.TrimPrefix(directive, "short="))
   106  
   107  				case strings.HasPrefix(directive, "default="):
   108  					defval = strings.TrimPrefix(directive, "default=")
   109  					if j+1 < len(directives) {
   110  						// Commas aren't escaped, and default is defined to be last.
   111  						defval += "," + strings.Join(directives[j+1:], ",")
   112  						break directivesLoop
   113  					}
   114  
   115  				case strings.HasPrefix(directive, "def="):
   116  					defval = strings.TrimPrefix(directive, "def=")
   117  					if j+1 < len(directives) {
   118  						// Commas aren't escaped, and def is defined to be last.
   119  						defval += "," + strings.Join(directives[j+1:], ",")
   120  						break directivesLoop
   121  					}
   122  				}
   123  			}
   124  		}
   125  
   126  		if prefix != "" {
   127  			name = prefix + "-" + name
   128  		}
   129  
   130  		if strings.Contains(name, "=") || strings.HasPrefix(name, "-") {
   131  			return fmt.Errorf("invalid flag name for field %s: %s", field.Name, name)
   132  		}
   133  
   134  		switch val.Kind() {
   135  		case reflect.Ptr:
   136  			if val.IsNil() {
   137  				// If the pointer is nil, then allocate the appropriate type
   138  				// and assign it into the pointer.
   139  				p := reflect.New(val.Type().Elem())
   140  				val.Set(p)
   141  			}
   142  
   143  			// We prefer to work with direct non-pointer values... unless it implements Value.
   144  
   145  			if _, ok := val.Interface().(Value); !ok {
   146  				// If val does not implement Value, dereference it,
   147  				// to work with the direct value.
   148  				val = val.Elem()
   149  			}
   150  
   151  		default:
   152  			// We prefer to use something that implements Value.
   153  			// So, here we reference the value, if:
   154  			// * the value does not implement Value,
   155  			// * the value’s pointer type does implement Value.
   156  			if _, ok := val.Interface().(Value); !ok {
   157  				// ensure that the value itself does not implement Value.
   158  
   159  				pval := val.Addr() // reference the value.
   160  				if _, ok := pval.Interface().(Value); ok {
   161  					// if the pointer implements Value, then let's use that.
   162  					val = pval
   163  				}
   164  			}
   165  		}
   166  
   167  		// We set value such that we can generically just use fs.Var to setup the flag,
   168  		// any other FlagSet.TypeVar will overwrite the value that is stored in that field,
   169  		// which means we wouldn’t get that value as the default.
   170  		// But we want the value in the field as default, even if no `flag:",default=val"` is given.
   171  		var value Value
   172  
   173  		switch v := val.Interface().(type) {
   174  		case EnumValue: // EnumValues implements Value, so we need to check this first.
   175  			enum := &enumValue{
   176  				val: (*int)(unsafe.Pointer(val.UnsafeAddr())),
   177  			}
   178  			value = enum
   179  
   180  			if tag := field.Tag.Get("values"); tag != "" {
   181  				enum.setValid(strings.Split(tag, ","))
   182  			}
   183  
   184  		case Value:
   185  			// this is obviously the simplest option… the work is already done.
   186  			value = v
   187  
   188  		case bool:
   189  			value = (*boolValue)(unsafe.Pointer(val.UnsafeAddr()))
   190  
   191  		case uint:
   192  			value = (*uintValue)(unsafe.Pointer(val.UnsafeAddr()))
   193  		case []uint:
   194  			slice := (*[]uint)(unsafe.Pointer(val.UnsafeAddr()))
   195  
   196  			value = newSlice(slice, func(s string) error {
   197  				u, err := strconv.ParseUint(s, 0, strconv.IntSize)
   198  				if err != nil {
   199  					return err
   200  				}
   201  
   202  				*slice = append(*slice, uint(u))
   203  				return nil
   204  			})
   205  
   206  		case uint64:
   207  			value = (*uint64Value)(unsafe.Pointer(val.UnsafeAddr()))
   208  		case []uint64:
   209  			slice := (*[]uint64)(unsafe.Pointer(val.UnsafeAddr()))
   210  
   211  			value = newSlice(slice, func(s string) error {
   212  				u, err := strconv.ParseUint(s, 0, 64)
   213  				if err != nil {
   214  					return err
   215  				}
   216  
   217  				*slice = append(*slice, u)
   218  				return nil
   219  			})
   220  
   221  		case uint8, uint16, uint32:
   222  			width := val.Type().Size() * 8
   223  
   224  			if defval == "" {
   225  				z := reflect.Zero(val.Type())
   226  				if z.Interface() != val.Interface() {
   227  					defval = fmt.Sprint(val)
   228  				}
   229  			}
   230  
   231  			// here we support a few additional types with generic-ish reflection
   232  			value = newFunc(fmt.Sprint(val.Type()), func(s string) error {
   233  				u, err := strconv.ParseUint(s, 0, int(width))
   234  				if err != nil {
   235  					return err
   236  				}
   237  
   238  				val.SetUint(u)
   239  				return nil
   240  			})
   241  
   242  		case int:
   243  			value = (*intValue)(unsafe.Pointer(val.UnsafeAddr()))
   244  		case []int:
   245  			slice := (*[]int)(unsafe.Pointer(val.UnsafeAddr()))
   246  
   247  			value = newSlice(slice, func(s string) error {
   248  				i, err := strconv.ParseInt(s, 0, strconv.IntSize)
   249  				if err != nil {
   250  					return err
   251  				}
   252  
   253  				*slice = append(*slice, int(i))
   254  				return nil
   255  			})
   256  
   257  		case int64:
   258  			value = (*int64Value)(unsafe.Pointer(val.UnsafeAddr()))
   259  		case []int64:
   260  			slice := (*[]int64)(unsafe.Pointer(val.UnsafeAddr()))
   261  
   262  			value = newSlice(slice, func(s string) error {
   263  				i, err := strconv.ParseInt(s, 0, 64)
   264  				if err != nil {
   265  					return err
   266  				}
   267  
   268  				*slice = append(*slice, i)
   269  				return nil
   270  			})
   271  
   272  		case int8, int16, int32:
   273  			width := val.Type().Size() * 8
   274  
   275  			if defval == "" {
   276  				z := reflect.Zero(val.Type())
   277  				if z.Interface() != val.Interface() {
   278  					defval = fmt.Sprint(val)
   279  				}
   280  			}
   281  
   282  			// here we support a few additional types with generic-ish reflection
   283  			value = newFunc(fmt.Sprint(val.Type()), func(s string) error {
   284  				i, err := strconv.ParseInt(s, 0, int(width))
   285  				if err != nil {
   286  					return err
   287  				}
   288  
   289  				val.SetInt(i)
   290  				return nil
   291  			})
   292  
   293  		case float64:
   294  			value = (*float64Value)(unsafe.Pointer(val.UnsafeAddr()))
   295  		case []float64:
   296  			slice := (*[]float64)(unsafe.Pointer(val.UnsafeAddr()))
   297  
   298  			value = newSlice(slice, func(s string) error {
   299  				f, err := strconv.ParseFloat(s, 64)
   300  				if err != nil {
   301  					return err
   302  				}
   303  
   304  				*slice = append(*slice, f)
   305  				return nil
   306  			})
   307  
   308  		case float32:
   309  			if defval == "" {
   310  				z := reflect.Zero(val.Type())
   311  				if z.Interface() != val.Interface() {
   312  					defval = fmt.Sprint(val)
   313  				}
   314  			}
   315  
   316  			// here we support float32 with generic-ish reflection
   317  			value = newFunc("float32", func(s string) error {
   318  				f, err := strconv.ParseFloat(s, 32)
   319  				if err != nil {
   320  					return err
   321  				}
   322  
   323  				val.SetFloat(f)
   324  				return nil
   325  			})
   326  
   327  		case string:
   328  			value = (*stringValue)(unsafe.Pointer(val.UnsafeAddr()))
   329  		case []string:
   330  			slice := (*[]string)(unsafe.Pointer(val.UnsafeAddr()))
   331  
   332  			value = newSlice(slice, func(s string) error {
   333  				*slice = append(*slice, s)
   334  				return nil
   335  			})
   336  
   337  		case []byte:
   338  			// just like string, but stored as []byte
   339  			value = newFunc(fmt.Sprint(field), func(s string) error {
   340  				val.SetBytes([]byte(s))
   341  				return nil
   342  			})
   343  
   344  		case time.Duration:
   345  			value = (*durationValue)(unsafe.Pointer(val.UnsafeAddr()))
   346  		case []time.Duration:
   347  			slice := (*[]time.Duration)(unsafe.Pointer(val.UnsafeAddr()))
   348  
   349  			value = newSlice(slice, func(s string) error {
   350  				d, err := time.ParseDuration(s)
   351  				if err != nil {
   352  					return err
   353  				}
   354  
   355  				*slice = append(*slice, d)
   356  				return nil
   357  			})
   358  
   359  		// From our code above, we already dereferenced pointers, so this is why not `*url.URL`
   360  		case url.URL:
   361  			set := (*url.URL)(unsafe.Pointer(val.UnsafeAddr()))
   362  
   363  			if defval == "" {
   364  				z := reflect.Zero(val.Type())
   365  				if z.Interface() != val.Interface() {
   366  					defval = set.String()
   367  				}
   368  			}
   369  
   370  			value = newFunc(fmt.Sprint(field), func(s string) error {
   371  				uri, err := url.Parse(s)
   372  				if err != nil {
   373  					return err
   374  				}
   375  
   376  				*set = *uri
   377  				return nil
   378  			})
   379  		case []*url.URL:
   380  			slice := (*[]*url.URL)(unsafe.Pointer(val.UnsafeAddr()))
   381  
   382  			value = newSlice(slice, func(s string) error {
   383  				uri, err := url.Parse(s)
   384  				if err != nil {
   385  					return err
   386  				}
   387  
   388  				*slice = append(*slice, uri)
   389  				return nil
   390  			})
   391  
   392  		default:
   393  			if val.Kind() != reflect.Struct {
   394  				return fmt.Errorf("gnuflag: unsupported type %q for %q", field.Type, field.Name)
   395  			}
   396  
   397  			if err := fs.structVar(name, val); err != nil {
   398  				return err
   399  			}
   400  
   401  			// Do not setup the fs.Var like all the other paths.
   402  			continue
   403  		}
   404  
   405  		var opts []Option
   406  		if short != 0 {
   407  			opts = append(opts, WithShort(short))
   408  		}
   409  		if defval != "" {
   410  			opts = append(opts, WithDefault(defval))
   411  		}
   412  
   413  		if err := fs.Var(value, name, usage, opts...); err != nil {
   414  			return err
   415  		}
   416  	}
   417  
   418  	return nil
   419  }
   420  
   421  // Struct uses reflection to take a structure and turn it into a series of flags.
   422  // It recognizes the struct tags of `flag:"flag-name,short=F,default=defval"` and `desc:"usage"`.
   423  // The "desc" tag is intended to be much more generic than just for use in this library.
   424  // To ignore a struct value use the tag `flag:"-"`, and `flag:","` will use the variable’s name.
   425  func (fs *FlagSet) Struct(prefix string, value interface{}) error {
   426  	v := reflect.ValueOf(value)
   427  	if v.Kind() != reflect.Ptr || v.IsNil() {
   428  		return fmt.Errorf("gnuflag.FlagSet.Struct on non-pointer: %v", v.Kind())
   429  	}
   430  
   431  	v = v.Elem()
   432  
   433  	if v.Kind() != reflect.Struct {
   434  		return fmt.Errorf("gnuflag.FlagSet.Struct on non-struct: %v", v.Kind())
   435  	}
   436  
   437  	return fs.structVar(prefix, v)
   438  }
   439  
   440  // Struct uses default CommandLine flagset.
   441  func Struct(prefix string, value interface{}) error {
   442  	return CommandLine.Struct(prefix, value)
   443  }