github.com/lyft/flytestdlib@v0.3.12-0.20210213045714-8cdd111ecda1/cli/pflags/api/generator.go (about)

     1  package api
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"go/types"
     7  	"path/filepath"
     8  	"strings"
     9  
    10  	"github.com/lyft/flytestdlib/logger"
    11  
    12  	"golang.org/x/tools/go/packages"
    13  
    14  	"github.com/ernesto-jimenez/gogen/gogenutil"
    15  )
    16  
    17  const (
    18  	indent = "  "
    19  )
    20  
    21  // PFlagProviderGenerator parses and generates GetPFlagSet implementation to add PFlags for a given struct's fields.
    22  type PFlagProviderGenerator struct {
    23  	pkg                  *types.Package
    24  	st                   *types.Named
    25  	defaultVar           *types.Var
    26  	shouldBindDefaultVar bool
    27  }
    28  
    29  // This list is restricted because that's the only kinds viper parses out, otherwise it assumes strings.
    30  // github.com/spf13/viper/viper.go:1016
    31  var allowedKinds = []types.Type{
    32  	types.Typ[types.Int],
    33  	types.Typ[types.Int8],
    34  	types.Typ[types.Int16],
    35  	types.Typ[types.Int32],
    36  	types.Typ[types.Int64],
    37  	types.Typ[types.Bool],
    38  	types.Typ[types.String],
    39  }
    40  
    41  type SliceOrArray interface {
    42  	Elem() types.Type
    43  }
    44  
    45  func capitalize(s string) string {
    46  	if s[0] >= 'a' && s[0] <= 'z' {
    47  		return string(s[0]-'a'+'A') + s[1:]
    48  	}
    49  
    50  	return s
    51  }
    52  
    53  func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage, defaultValue string, bindDefaultVar bool) (FieldInfo, error) {
    54  	strategy := SliceRaw
    55  	FlagMethodName := "StringSlice"
    56  	typ := types.NewSlice(types.Typ[types.String])
    57  	emptyDefaultValue := `[]string{}`
    58  	if b, ok := t.Elem().(*types.Basic); !ok {
    59  		logger.Infof(ctx, "Elem of type [%v] is not a basic type. It must be json unmarshalable or generation will fail.", t.Elem())
    60  		if !isJSONUnmarshaler(t.Elem()) {
    61  			return FieldInfo{},
    62  				fmt.Errorf("slice of type [%v] is not supported. Only basic slices or slices of json-unmarshalable types are supported",
    63  					t.Elem().String())
    64  		}
    65  	} else {
    66  		logger.Infof(ctx, "Elem of type [%v] is a basic type. Will use a pflag as a Slice.", b)
    67  		strategy = SliceJoined
    68  		FlagMethodName = fmt.Sprintf("%vSlice", capitalize(b.Name()))
    69  		typ = types.NewSlice(b)
    70  		emptyDefaultValue = fmt.Sprintf(`[]%v{}`, b.Name())
    71  	}
    72  
    73  	testValue := defaultValue
    74  	if len(defaultValue) == 0 {
    75  		defaultValue = emptyDefaultValue
    76  		testValue = `"1,1"`
    77  	}
    78  
    79  	return FieldInfo{
    80  		Name:              name,
    81  		GoName:            goName,
    82  		Typ:               typ,
    83  		FlagMethodName:    FlagMethodName,
    84  		DefaultValue:      defaultValue,
    85  		UsageString:       usage,
    86  		TestValue:         testValue,
    87  		TestStrategy:      strategy,
    88  		ShouldBindDefault: bindDefaultVar,
    89  	}, nil
    90  }
    91  
    92  // Appends field accessors using "." as the delimiter.
    93  // e.g. appendAccessors("var1", "field1", "subField") will output "var1.field1.subField"
    94  func appendAccessors(accessors ...string) string {
    95  	sb := strings.Builder{}
    96  	switch len(accessors) {
    97  	case 0:
    98  		return ""
    99  	case 1:
   100  		return accessors[0]
   101  	}
   102  
   103  	for _, s := range accessors {
   104  		if len(s) > 0 {
   105  			if sb.Len() > 0 {
   106  				if _, err := sb.WriteString("."); err != nil {
   107  					fmt.Printf("Failed to writeString, error: %v", err)
   108  					return ""
   109  				}
   110  			}
   111  
   112  			if _, err := sb.WriteString(s); err != nil {
   113  				fmt.Printf("Failed to writeString, error: %v", err)
   114  				return ""
   115  			}
   116  		}
   117  	}
   118  
   119  	return sb.String()
   120  }
   121  
   122  // Traverses fields in type and follows recursion tree to discover all fields. It stops when one of two conditions is
   123  // met; encountered a basic type (e.g. string, int... etc.) or the field type implements UnmarshalJSON.
   124  // If passed a non-empty defaultValueAccessor, it'll be used to fill in default values instead of any default value
   125  // specified in pflag tag.
   126  func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor, fieldPath string, bindDefaultVar bool) ([]FieldInfo, error) {
   127  	logger.Printf(ctx, "Finding all fields in [%v.%v.%v]",
   128  		typ.Obj().Pkg().Path(), typ.Obj().Pkg().Name(), typ.Obj().Name())
   129  
   130  	ctx = logger.WithIndent(ctx, indent)
   131  
   132  	st := typ.Underlying().(*types.Struct)
   133  	fields := make([]FieldInfo, 0, st.NumFields())
   134  	for i := 0; i < st.NumFields(); i++ {
   135  		v := st.Field(i)
   136  		if !v.IsField() {
   137  			continue
   138  		}
   139  
   140  		// Parses out the tag if one exists.
   141  		tag, err := ParseTag(st.Tag(i))
   142  		if err != nil {
   143  			return nil, err
   144  		}
   145  
   146  		if len(tag.Name) == 0 {
   147  			tag.Name = v.Name()
   148  		}
   149  
   150  		if tag.DefaultValue == "-" {
   151  			logger.Infof(ctx, "Skipping field [%s], as '-' value detected", tag.Name)
   152  			continue
   153  		}
   154  
   155  		typ := v.Type()
   156  		ptr, isPtr := typ.(*types.Pointer)
   157  		if isPtr {
   158  			typ = ptr.Elem()
   159  		}
   160  
   161  		switch t := typ.(type) {
   162  		case *types.Basic:
   163  			if len(tag.DefaultValue) == 0 {
   164  				tag.DefaultValue = fmt.Sprintf("*new(%v)", typ.String())
   165  			}
   166  
   167  			logger.Infof(ctx, "[%v] is of a basic type with default value [%v].", tag.Name, tag.DefaultValue)
   168  
   169  			isAllowed := false
   170  			for _, k := range allowedKinds {
   171  				if t.String() == k.String() {
   172  					isAllowed = true
   173  					break
   174  				}
   175  			}
   176  
   177  			if !isAllowed {
   178  				return nil, fmt.Errorf("only these basic kinds are allowed. given [%v] (Kind: [%v]. expected: [%+v]",
   179  					t.String(), t.Kind(), allowedKinds)
   180  			}
   181  
   182  			defaultValue := tag.DefaultValue
   183  			if len(defaultValueAccessor) > 0 {
   184  				defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name())
   185  
   186  				if isPtr {
   187  					defaultValue = fmt.Sprintf("%s.elemValueOrNil(%s).(%s)", defaultValueAccessor, defaultValue, t.Name())
   188  				}
   189  			}
   190  
   191  			fields = append(fields, FieldInfo{
   192  				Name:              tag.Name,
   193  				GoName:            v.Name(),
   194  				Typ:               t,
   195  				FlagMethodName:    camelCase(t.String()),
   196  				DefaultValue:      defaultValue,
   197  				UsageString:       tag.Usage,
   198  				TestValue:         `"1"`,
   199  				TestStrategy:      JSON,
   200  				ShouldBindDefault: bindDefaultVar,
   201  			})
   202  		case *types.Named:
   203  			if _, isStruct := t.Underlying().(*types.Struct); !isStruct {
   204  				// TODO: Add a more descriptive error message.
   205  				return nil, fmt.Errorf("invalid type. it must be struct, received [%v] for field [%v]", t.Underlying().String(), tag.Name)
   206  			}
   207  
   208  			// If the type has json unmarshaler, then stop the recursion and assume the type is string. config package
   209  			// will use json unmarshaler to fill in the final config object.
   210  			jsonUnmarshaler := isJSONUnmarshaler(t)
   211  
   212  			defaultValue := tag.DefaultValue
   213  			if len(defaultValueAccessor) > 0 {
   214  				defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name())
   215  				if isStringer(t) {
   216  					defaultValue = defaultValue + ".String()"
   217  				} else {
   218  					logger.Infof(ctx, "Field [%v] of type [%v] does not implement Stringer interface."+
   219  						" Will use %s.mustMarshalJSON() to get its default value.", defaultValueAccessor, v.Name(), t.String())
   220  					defaultValue = fmt.Sprintf("%s.mustMarshalJSON(%s)", defaultValueAccessor, defaultValue)
   221  				}
   222  			}
   223  
   224  			testValue := defaultValue
   225  			if len(testValue) == 0 {
   226  				testValue = `"1"`
   227  			}
   228  
   229  			logger.Infof(ctx, "[%v] is of a Named type (struct) with default value [%v].", tag.Name, tag.DefaultValue)
   230  
   231  			if jsonUnmarshaler {
   232  				logger.Infof(logger.WithIndent(ctx, indent), "Type is json unmarhslalable.")
   233  
   234  				fields = append(fields, FieldInfo{
   235  					Name:              tag.Name,
   236  					GoName:            v.Name(),
   237  					Typ:               types.Typ[types.String],
   238  					FlagMethodName:    "String",
   239  					DefaultValue:      defaultValue,
   240  					UsageString:       tag.Usage,
   241  					TestValue:         testValue,
   242  					TestStrategy:      JSON,
   243  					ShouldBindDefault: bindDefaultVar,
   244  				})
   245  			} else {
   246  				logger.Infof(ctx, "Traversing fields in type.")
   247  
   248  				nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, defaultValueAccessor, appendAccessors(fieldPath, v.Name()), bindDefaultVar)
   249  				if err != nil {
   250  					return nil, err
   251  				}
   252  
   253  				for _, subField := range nested {
   254  					fields = append(fields, FieldInfo{
   255  						Name:              fmt.Sprintf("%v.%v", tag.Name, subField.Name),
   256  						GoName:            fmt.Sprintf("%v.%v", v.Name(), subField.GoName),
   257  						Typ:               subField.Typ,
   258  						FlagMethodName:    subField.FlagMethodName,
   259  						DefaultValue:      subField.DefaultValue,
   260  						UsageString:       subField.UsageString,
   261  						TestValue:         subField.TestValue,
   262  						TestStrategy:      subField.TestStrategy,
   263  						ShouldBindDefault: bindDefaultVar,
   264  					})
   265  				}
   266  			}
   267  		case *types.Slice:
   268  			logger.Infof(ctx, "[%v] is of a slice type with default value [%v].", tag.Name, tag.DefaultValue)
   269  
   270  			f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, tag.DefaultValue, bindDefaultVar)
   271  			if err != nil {
   272  				return nil, err
   273  			}
   274  
   275  			fields = append(fields, f)
   276  		case *types.Array:
   277  			logger.Infof(ctx, "[%v] is of an array with default value [%v].", tag.Name, tag.DefaultValue)
   278  
   279  			f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, tag.DefaultValue, bindDefaultVar)
   280  			if err != nil {
   281  				return nil, err
   282  			}
   283  
   284  			fields = append(fields, f)
   285  		default:
   286  			return nil, fmt.Errorf("unexpected type %v", t.String())
   287  		}
   288  	}
   289  
   290  	return fields, nil
   291  }
   292  
   293  // NewGenerator initializes a PFlagProviderGenerator for pflags files for targetTypeName struct under pkg. If pkg is not filled in,
   294  // it's assumed to be current package (which is expected to be the common use case when invoking pflags from
   295  // go:generate comments)
   296  func NewGenerator(pkg, targetTypeName, defaultVariableName string, shouldBindDefaultVar bool) (*PFlagProviderGenerator, error) {
   297  	ctx := context.Background()
   298  	var err error
   299  
   300  	// Resolve package path
   301  	if pkg == "" || pkg[0] == '.' {
   302  		pkg, err = filepath.Abs(filepath.Clean(pkg))
   303  		if err != nil {
   304  			return nil, err
   305  		}
   306  
   307  		pkg = gogenutil.StripGopath(pkg)
   308  		logger.InfofNoCtx("Loading package from path [%v]", pkg)
   309  	}
   310  
   311  	targetPackage, err := loadPackage(pkg)
   312  	if err != nil {
   313  		return nil, err
   314  	}
   315  
   316  	obj := targetPackage.Scope().Lookup(targetTypeName)
   317  	if obj == nil {
   318  		return nil, fmt.Errorf("struct %s missing", targetTypeName)
   319  	}
   320  
   321  	var st *types.Named
   322  	switch obj.Type().Underlying().(type) {
   323  	case *types.Struct:
   324  		st = obj.Type().(*types.Named)
   325  	default:
   326  		return nil, fmt.Errorf("%s should be an struct, was %s", targetTypeName, obj.Type().Underlying())
   327  	}
   328  
   329  	var defaultVar *types.Var
   330  	obj = targetPackage.Scope().Lookup(defaultVariableName)
   331  	if obj != nil {
   332  		defaultVar = obj.(*types.Var)
   333  	}
   334  
   335  	if defaultVar != nil {
   336  		logger.Infof(ctx, "Using default variable with name [%v] to assign all default values.", defaultVariableName)
   337  	} else {
   338  		logger.Infof(ctx, "Using default values defined in tags if any.")
   339  	}
   340  
   341  	return &PFlagProviderGenerator{
   342  		st:                   st,
   343  		pkg:                  targetPackage,
   344  		defaultVar:           defaultVar,
   345  		shouldBindDefaultVar: shouldBindDefaultVar,
   346  	}, nil
   347  }
   348  
   349  func loadPackage(pkg string) (*types.Package, error) {
   350  	config := &packages.Config{
   351  		Mode: packages.NeedTypes | packages.NeedTypesInfo,
   352  		Logf: logger.InfofNoCtx,
   353  	}
   354  
   355  	loadedPkgs, err := packages.Load(config, pkg)
   356  	if err != nil {
   357  		return nil, err
   358  	}
   359  
   360  	if len(loadedPkgs) == 0 {
   361  		return nil, fmt.Errorf("No packages loaded")
   362  	}
   363  
   364  	targetPackage := loadedPkgs[0].Types
   365  	return targetPackage, nil
   366  }
   367  
   368  func (g PFlagProviderGenerator) GetTargetPackage() *types.Package {
   369  	return g.pkg
   370  }
   371  
   372  func (g PFlagProviderGenerator) Generate(ctx context.Context) (PFlagProvider, error) {
   373  	defaultValueAccessor := ""
   374  	if g.defaultVar != nil {
   375  		defaultValueAccessor = g.defaultVar.Name()
   376  	}
   377  
   378  	fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor, "", g.shouldBindDefaultVar)
   379  	if err != nil {
   380  		return PFlagProvider{}, err
   381  	}
   382  
   383  	return newPflagProvider(g.pkg, g.st.Obj().Name(), fields), nil
   384  }