github.com/octohelm/cuemod@v0.9.4/pkg/cli/app.go (about)

     1  package cli
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  	"reflect"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"github.com/spf13/cobra"
    12  	"github.com/spf13/pflag"
    13  )
    14  
    15  func NewApp(name string, version string, flags ...any) *App {
    16  	return &App{
    17  		Name:    Name{Name: name},
    18  		Version: version,
    19  		flags:   flags,
    20  	}
    21  }
    22  
    23  type App struct {
    24  	Name
    25  	Version string
    26  	flags   []any
    27  }
    28  
    29  func (a *App) PreRun(ctx context.Context) context.Context {
    30  	for i := range a.flags {
    31  		if preRun, ok := a.flags[i].(CanPreRun); ok {
    32  			ctx = preRun.PreRun(ctx)
    33  		}
    34  	}
    35  	return ctx
    36  }
    37  
    38  func (a *App) Run(ctx context.Context, args []string) error {
    39  	app := a.newCmdFrom(a, nil)
    40  
    41  	// bind global flags
    42  	for i := range a.flags {
    43  		a.bindCommand(app, a.flags[i], a.Naming())
    44  	}
    45  
    46  	return app.ExecuteContext(ctx)
    47  }
    48  
    49  func (a *App) newCmdFrom(cc Command, parent Command) *cobra.Command {
    50  	name := cc.Naming()
    51  	name.parent = parent
    52  
    53  	c := &cobra.Command{
    54  		Version: a.Version,
    55  	}
    56  
    57  	a.bindCommand(c, cc, name)
    58  
    59  	c.Args = func(cmd *cobra.Command, args []string) error {
    60  		return name.ValidArgs.Validate(args)
    61  	}
    62  
    63  	c.RunE = func(cmd *cobra.Command, args []string) error {
    64  		ctx := cmd.Context()
    65  
    66  		// run parent PreRun if exists
    67  		parents := make([]Command, 0)
    68  		for p := parent; p != nil; p = p.Naming().parent {
    69  			parents = append(parents, p)
    70  		}
    71  		for i := range parents {
    72  			if canPreRun, ok := parents[len(parents)-1-i].(CanPreRun); ok {
    73  				ctx = canPreRun.PreRun(ctx)
    74  			}
    75  		}
    76  
    77  		if preRun, ok := cc.(CanPreRun); ok {
    78  			ctx = preRun.PreRun(ctx)
    79  		}
    80  
    81  		return cc.Run(ctx, args)
    82  	}
    83  
    84  	for i := range name.subcommands {
    85  		c.AddCommand(a.newCmdFrom(name.subcommands[i], cc))
    86  	}
    87  
    88  	return c
    89  }
    90  
    91  func (a *App) bindCommand(c *cobra.Command, v any, n *Name) {
    92  	rv := reflect.ValueOf(v)
    93  
    94  	if rv.Kind() != reflect.Ptr || rv.Elem().Kind() != reflect.Struct {
    95  		panic(fmt.Errorf("only support a ptr struct value, but got %#v", v))
    96  	}
    97  
    98  	rv = rv.Elem()
    99  
   100  	if n.Name == "" {
   101  		n.Name = strings.ToLower(rv.Type().Name())
   102  	}
   103  
   104  	a.bindFromReflectValue(c, rv)
   105  
   106  	c.Use = fmt.Sprintf("%s [flags] %s", n.Name, strings.Join(n.ValidArgs, " "))
   107  	c.Short = n.Desc
   108  }
   109  
   110  func (a *App) bindFromReflectValue(c *cobra.Command, rv reflect.Value) {
   111  	t := rv.Type()
   112  
   113  	for i := 0; i < t.NumField(); i++ {
   114  		ft := t.Field(i)
   115  		fv := rv.Field(i)
   116  
   117  		if ft.Anonymous && ft.Type.Kind() == reflect.Struct {
   118  			if n, ok := fv.Addr().Interface().(*Name); ok {
   119  				n.Desc = ft.Tag.Get("desc")
   120  				if v, ok := ft.Tag.Lookup("args"); ok {
   121  					n.ValidArgs = ParseValidArgs(v)
   122  				}
   123  				continue
   124  			}
   125  			a.bindFromReflectValue(c, fv)
   126  			continue
   127  		}
   128  
   129  		if n, ok := ft.Tag.Lookup("flag"); ok {
   130  			parts := strings.SplitN(n, ",", 2)
   131  
   132  			name, alias := parts[0], strings.Join(parts[1:], "")
   133  
   134  			persistent := false
   135  
   136  			if len(name) > 0 && name[0] == '!' {
   137  				persistent = true
   138  				name = name[1:]
   139  			}
   140  
   141  			var envVars []string
   142  			if tagEnv, ok := ft.Tag.Lookup("env"); ok {
   143  				envVars = strings.Split(tagEnv, ",")
   144  			}
   145  
   146  			defaultText, defaultExists := ft.Tag.Lookup("default")
   147  
   148  			ff := &flagVar{
   149  				Name:        name,
   150  				Alias:       alias,
   151  				EnvVars:     envVars,
   152  				Default:     defaultText,
   153  				Required:    !defaultExists,
   154  				Desc:        ft.Tag.Get("desc"),
   155  				Destination: fv.Addr().Interface(),
   156  			}
   157  
   158  			if persistent {
   159  				if err := ff.Apply(c.PersistentFlags()); err != nil {
   160  					panic(err)
   161  				}
   162  			} else {
   163  				if err := ff.Apply(c.Flags()); err != nil {
   164  					panic(err)
   165  				}
   166  			}
   167  
   168  		}
   169  	}
   170  }
   171  
   172  type flagVar struct {
   173  	Name        string
   174  	Desc        string
   175  	Default     string
   176  	Required    bool
   177  	Alias       string
   178  	EnvVars     []string
   179  	Destination any
   180  }
   181  
   182  func (f *flagVar) DefaultValue() string {
   183  	v := f.Default
   184  	for i := range f.EnvVars {
   185  		if found, ok := os.LookupEnv(f.EnvVars[i]); ok {
   186  			v = found
   187  			break
   188  		}
   189  	}
   190  	return v
   191  }
   192  
   193  func (f *flagVar) Usage() string {
   194  	if len(f.EnvVars) > 0 {
   195  		s := strings.Builder{}
   196  		s.WriteString(f.Desc)
   197  		s.WriteString(" [")
   198  
   199  		for i, envVar := range f.EnvVars {
   200  			if i > 0 {
   201  				s.WriteString(",")
   202  			}
   203  			s.WriteString("$")
   204  			s.WriteString(envVar)
   205  		}
   206  
   207  		s.WriteString("]")
   208  		return s.String()
   209  	}
   210  	return f.Desc
   211  }
   212  
   213  func (f *flagVar) Apply(flags *pflag.FlagSet) error {
   214  	switch d := f.Destination.(type) {
   215  	case *[]string:
   216  		var v []string
   217  		if sv := f.DefaultValue(); sv != "" {
   218  			v = strings.Split(sv, ",")
   219  		}
   220  		flags.StringSliceVarP(d, f.Name, f.Alias, v, f.Usage())
   221  	case *string:
   222  		v := f.DefaultValue()
   223  		flags.StringVarP(d, f.Name, f.Alias, v, f.Usage())
   224  	case *int:
   225  		var v int
   226  		if sv := f.DefaultValue(); sv != "" {
   227  			v, _ = strconv.Atoi(sv)
   228  		}
   229  		flags.IntVarP(d, f.Name, f.Alias, v, f.Usage())
   230  	case *bool:
   231  		var v bool
   232  		if sv := f.DefaultValue(); sv != "" {
   233  			v, _ = strconv.ParseBool(sv)
   234  		}
   235  		flags.BoolVarP(d, f.Name, f.Alias, v, f.Usage())
   236  	}
   237  	return nil
   238  }