github.com/lyft/flytestdlib@v0.3.12-0.20210213045714-8cdd111ecda1/cli/pflags/api/generator.go (about) 1 package api 2 3 import ( 4 "context" 5 "fmt" 6 "go/types" 7 "path/filepath" 8 "strings" 9 10 "github.com/lyft/flytestdlib/logger" 11 12 "golang.org/x/tools/go/packages" 13 14 "github.com/ernesto-jimenez/gogen/gogenutil" 15 ) 16 17 const ( 18 indent = " " 19 ) 20 21 // PFlagProviderGenerator parses and generates GetPFlagSet implementation to add PFlags for a given struct's fields. 22 type PFlagProviderGenerator struct { 23 pkg *types.Package 24 st *types.Named 25 defaultVar *types.Var 26 shouldBindDefaultVar bool 27 } 28 29 // This list is restricted because that's the only kinds viper parses out, otherwise it assumes strings. 30 // github.com/spf13/viper/viper.go:1016 31 var allowedKinds = []types.Type{ 32 types.Typ[types.Int], 33 types.Typ[types.Int8], 34 types.Typ[types.Int16], 35 types.Typ[types.Int32], 36 types.Typ[types.Int64], 37 types.Typ[types.Bool], 38 types.Typ[types.String], 39 } 40 41 type SliceOrArray interface { 42 Elem() types.Type 43 } 44 45 func capitalize(s string) string { 46 if s[0] >= 'a' && s[0] <= 'z' { 47 return string(s[0]-'a'+'A') + s[1:] 48 } 49 50 return s 51 } 52 53 func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage, defaultValue string, bindDefaultVar bool) (FieldInfo, error) { 54 strategy := SliceRaw 55 FlagMethodName := "StringSlice" 56 typ := types.NewSlice(types.Typ[types.String]) 57 emptyDefaultValue := `[]string{}` 58 if b, ok := t.Elem().(*types.Basic); !ok { 59 logger.Infof(ctx, "Elem of type [%v] is not a basic type. It must be json unmarshalable or generation will fail.", t.Elem()) 60 if !isJSONUnmarshaler(t.Elem()) { 61 return FieldInfo{}, 62 fmt.Errorf("slice of type [%v] is not supported. Only basic slices or slices of json-unmarshalable types are supported", 63 t.Elem().String()) 64 } 65 } else { 66 logger.Infof(ctx, "Elem of type [%v] is a basic type. Will use a pflag as a Slice.", b) 67 strategy = SliceJoined 68 FlagMethodName = fmt.Sprintf("%vSlice", capitalize(b.Name())) 69 typ = types.NewSlice(b) 70 emptyDefaultValue = fmt.Sprintf(`[]%v{}`, b.Name()) 71 } 72 73 testValue := defaultValue 74 if len(defaultValue) == 0 { 75 defaultValue = emptyDefaultValue 76 testValue = `"1,1"` 77 } 78 79 return FieldInfo{ 80 Name: name, 81 GoName: goName, 82 Typ: typ, 83 FlagMethodName: FlagMethodName, 84 DefaultValue: defaultValue, 85 UsageString: usage, 86 TestValue: testValue, 87 TestStrategy: strategy, 88 ShouldBindDefault: bindDefaultVar, 89 }, nil 90 } 91 92 // Appends field accessors using "." as the delimiter. 93 // e.g. appendAccessors("var1", "field1", "subField") will output "var1.field1.subField" 94 func appendAccessors(accessors ...string) string { 95 sb := strings.Builder{} 96 switch len(accessors) { 97 case 0: 98 return "" 99 case 1: 100 return accessors[0] 101 } 102 103 for _, s := range accessors { 104 if len(s) > 0 { 105 if sb.Len() > 0 { 106 if _, err := sb.WriteString("."); err != nil { 107 fmt.Printf("Failed to writeString, error: %v", err) 108 return "" 109 } 110 } 111 112 if _, err := sb.WriteString(s); err != nil { 113 fmt.Printf("Failed to writeString, error: %v", err) 114 return "" 115 } 116 } 117 } 118 119 return sb.String() 120 } 121 122 // Traverses fields in type and follows recursion tree to discover all fields. It stops when one of two conditions is 123 // met; encountered a basic type (e.g. string, int... etc.) or the field type implements UnmarshalJSON. 124 // If passed a non-empty defaultValueAccessor, it'll be used to fill in default values instead of any default value 125 // specified in pflag tag. 126 func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor, fieldPath string, bindDefaultVar bool) ([]FieldInfo, error) { 127 logger.Printf(ctx, "Finding all fields in [%v.%v.%v]", 128 typ.Obj().Pkg().Path(), typ.Obj().Pkg().Name(), typ.Obj().Name()) 129 130 ctx = logger.WithIndent(ctx, indent) 131 132 st := typ.Underlying().(*types.Struct) 133 fields := make([]FieldInfo, 0, st.NumFields()) 134 for i := 0; i < st.NumFields(); i++ { 135 v := st.Field(i) 136 if !v.IsField() { 137 continue 138 } 139 140 // Parses out the tag if one exists. 141 tag, err := ParseTag(st.Tag(i)) 142 if err != nil { 143 return nil, err 144 } 145 146 if len(tag.Name) == 0 { 147 tag.Name = v.Name() 148 } 149 150 if tag.DefaultValue == "-" { 151 logger.Infof(ctx, "Skipping field [%s], as '-' value detected", tag.Name) 152 continue 153 } 154 155 typ := v.Type() 156 ptr, isPtr := typ.(*types.Pointer) 157 if isPtr { 158 typ = ptr.Elem() 159 } 160 161 switch t := typ.(type) { 162 case *types.Basic: 163 if len(tag.DefaultValue) == 0 { 164 tag.DefaultValue = fmt.Sprintf("*new(%v)", typ.String()) 165 } 166 167 logger.Infof(ctx, "[%v] is of a basic type with default value [%v].", tag.Name, tag.DefaultValue) 168 169 isAllowed := false 170 for _, k := range allowedKinds { 171 if t.String() == k.String() { 172 isAllowed = true 173 break 174 } 175 } 176 177 if !isAllowed { 178 return nil, fmt.Errorf("only these basic kinds are allowed. given [%v] (Kind: [%v]. expected: [%+v]", 179 t.String(), t.Kind(), allowedKinds) 180 } 181 182 defaultValue := tag.DefaultValue 183 if len(defaultValueAccessor) > 0 { 184 defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name()) 185 186 if isPtr { 187 defaultValue = fmt.Sprintf("%s.elemValueOrNil(%s).(%s)", defaultValueAccessor, defaultValue, t.Name()) 188 } 189 } 190 191 fields = append(fields, FieldInfo{ 192 Name: tag.Name, 193 GoName: v.Name(), 194 Typ: t, 195 FlagMethodName: camelCase(t.String()), 196 DefaultValue: defaultValue, 197 UsageString: tag.Usage, 198 TestValue: `"1"`, 199 TestStrategy: JSON, 200 ShouldBindDefault: bindDefaultVar, 201 }) 202 case *types.Named: 203 if _, isStruct := t.Underlying().(*types.Struct); !isStruct { 204 // TODO: Add a more descriptive error message. 205 return nil, fmt.Errorf("invalid type. it must be struct, received [%v] for field [%v]", t.Underlying().String(), tag.Name) 206 } 207 208 // If the type has json unmarshaler, then stop the recursion and assume the type is string. config package 209 // will use json unmarshaler to fill in the final config object. 210 jsonUnmarshaler := isJSONUnmarshaler(t) 211 212 defaultValue := tag.DefaultValue 213 if len(defaultValueAccessor) > 0 { 214 defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name()) 215 if isStringer(t) { 216 defaultValue = defaultValue + ".String()" 217 } else { 218 logger.Infof(ctx, "Field [%v] of type [%v] does not implement Stringer interface."+ 219 " Will use %s.mustMarshalJSON() to get its default value.", defaultValueAccessor, v.Name(), t.String()) 220 defaultValue = fmt.Sprintf("%s.mustMarshalJSON(%s)", defaultValueAccessor, defaultValue) 221 } 222 } 223 224 testValue := defaultValue 225 if len(testValue) == 0 { 226 testValue = `"1"` 227 } 228 229 logger.Infof(ctx, "[%v] is of a Named type (struct) with default value [%v].", tag.Name, tag.DefaultValue) 230 231 if jsonUnmarshaler { 232 logger.Infof(logger.WithIndent(ctx, indent), "Type is json unmarhslalable.") 233 234 fields = append(fields, FieldInfo{ 235 Name: tag.Name, 236 GoName: v.Name(), 237 Typ: types.Typ[types.String], 238 FlagMethodName: "String", 239 DefaultValue: defaultValue, 240 UsageString: tag.Usage, 241 TestValue: testValue, 242 TestStrategy: JSON, 243 ShouldBindDefault: bindDefaultVar, 244 }) 245 } else { 246 logger.Infof(ctx, "Traversing fields in type.") 247 248 nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, defaultValueAccessor, appendAccessors(fieldPath, v.Name()), bindDefaultVar) 249 if err != nil { 250 return nil, err 251 } 252 253 for _, subField := range nested { 254 fields = append(fields, FieldInfo{ 255 Name: fmt.Sprintf("%v.%v", tag.Name, subField.Name), 256 GoName: fmt.Sprintf("%v.%v", v.Name(), subField.GoName), 257 Typ: subField.Typ, 258 FlagMethodName: subField.FlagMethodName, 259 DefaultValue: subField.DefaultValue, 260 UsageString: subField.UsageString, 261 TestValue: subField.TestValue, 262 TestStrategy: subField.TestStrategy, 263 ShouldBindDefault: bindDefaultVar, 264 }) 265 } 266 } 267 case *types.Slice: 268 logger.Infof(ctx, "[%v] is of a slice type with default value [%v].", tag.Name, tag.DefaultValue) 269 270 f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, tag.DefaultValue, bindDefaultVar) 271 if err != nil { 272 return nil, err 273 } 274 275 fields = append(fields, f) 276 case *types.Array: 277 logger.Infof(ctx, "[%v] is of an array with default value [%v].", tag.Name, tag.DefaultValue) 278 279 f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, tag.DefaultValue, bindDefaultVar) 280 if err != nil { 281 return nil, err 282 } 283 284 fields = append(fields, f) 285 default: 286 return nil, fmt.Errorf("unexpected type %v", t.String()) 287 } 288 } 289 290 return fields, nil 291 } 292 293 // NewGenerator initializes a PFlagProviderGenerator for pflags files for targetTypeName struct under pkg. If pkg is not filled in, 294 // it's assumed to be current package (which is expected to be the common use case when invoking pflags from 295 // go:generate comments) 296 func NewGenerator(pkg, targetTypeName, defaultVariableName string, shouldBindDefaultVar bool) (*PFlagProviderGenerator, error) { 297 ctx := context.Background() 298 var err error 299 300 // Resolve package path 301 if pkg == "" || pkg[0] == '.' { 302 pkg, err = filepath.Abs(filepath.Clean(pkg)) 303 if err != nil { 304 return nil, err 305 } 306 307 pkg = gogenutil.StripGopath(pkg) 308 logger.InfofNoCtx("Loading package from path [%v]", pkg) 309 } 310 311 targetPackage, err := loadPackage(pkg) 312 if err != nil { 313 return nil, err 314 } 315 316 obj := targetPackage.Scope().Lookup(targetTypeName) 317 if obj == nil { 318 return nil, fmt.Errorf("struct %s missing", targetTypeName) 319 } 320 321 var st *types.Named 322 switch obj.Type().Underlying().(type) { 323 case *types.Struct: 324 st = obj.Type().(*types.Named) 325 default: 326 return nil, fmt.Errorf("%s should be an struct, was %s", targetTypeName, obj.Type().Underlying()) 327 } 328 329 var defaultVar *types.Var 330 obj = targetPackage.Scope().Lookup(defaultVariableName) 331 if obj != nil { 332 defaultVar = obj.(*types.Var) 333 } 334 335 if defaultVar != nil { 336 logger.Infof(ctx, "Using default variable with name [%v] to assign all default values.", defaultVariableName) 337 } else { 338 logger.Infof(ctx, "Using default values defined in tags if any.") 339 } 340 341 return &PFlagProviderGenerator{ 342 st: st, 343 pkg: targetPackage, 344 defaultVar: defaultVar, 345 shouldBindDefaultVar: shouldBindDefaultVar, 346 }, nil 347 } 348 349 func loadPackage(pkg string) (*types.Package, error) { 350 config := &packages.Config{ 351 Mode: packages.NeedTypes | packages.NeedTypesInfo, 352 Logf: logger.InfofNoCtx, 353 } 354 355 loadedPkgs, err := packages.Load(config, pkg) 356 if err != nil { 357 return nil, err 358 } 359 360 if len(loadedPkgs) == 0 { 361 return nil, fmt.Errorf("No packages loaded") 362 } 363 364 targetPackage := loadedPkgs[0].Types 365 return targetPackage, nil 366 } 367 368 func (g PFlagProviderGenerator) GetTargetPackage() *types.Package { 369 return g.pkg 370 } 371 372 func (g PFlagProviderGenerator) Generate(ctx context.Context) (PFlagProvider, error) { 373 defaultValueAccessor := "" 374 if g.defaultVar != nil { 375 defaultValueAccessor = g.defaultVar.Name() 376 } 377 378 fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor, "", g.shouldBindDefaultVar) 379 if err != nil { 380 return PFlagProvider{}, err 381 } 382 383 return newPflagProvider(g.pkg, g.st.Obj().Name(), fields), nil 384 }