github.com/lyft/flytestdlib@v0.3.12-0.20210213045714-8cdd111ecda1/cli/pflags/api/pflag_provider.go (about) 1 package api 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/types" 7 "io/ioutil" 8 "os" 9 "time" 10 11 "github.com/ernesto-jimenez/gogen/imports" 12 goimports "golang.org/x/tools/imports" 13 ) 14 15 type PFlagProvider struct { 16 typeName string 17 pkg *types.Package 18 fields []FieldInfo 19 } 20 21 // Adds any needed imports for types not directly declared in this package. 22 func (p PFlagProvider) Imports() map[string]string { 23 imp := imports.New(p.pkg.Name()) 24 for _, m := range p.fields { 25 imp.AddImportsFrom(m.Typ) 26 } 27 28 return imp.Imports() 29 } 30 31 // Evaluates the main code file template and writes the output to outputFilePath 32 func (p PFlagProvider) WriteCodeFile(outputFilePath string) error { 33 buf := bytes.Buffer{} 34 err := p.generate(GenerateCodeFile, &buf, outputFilePath) 35 if err != nil { 36 return fmt.Errorf("error generating code, Error: %v. Source: %v", err, buf.String()) 37 } 38 39 return p.writeToFile(&buf, outputFilePath) 40 } 41 42 // Evaluates the test code file template and writes the output to outputFilePath 43 func (p PFlagProvider) WriteTestFile(outputFilePath string) error { 44 buf := bytes.Buffer{} 45 err := p.generate(GenerateTestFile, &buf, outputFilePath) 46 if err != nil { 47 return fmt.Errorf("error generating code, Error: %v. Source: %v", err, buf.String()) 48 } 49 50 return p.writeToFile(&buf, outputFilePath) 51 } 52 53 func (p PFlagProvider) writeToFile(buffer *bytes.Buffer, fileName string) error { 54 return ioutil.WriteFile(fileName, buffer.Bytes(), os.ModePerm) 55 } 56 57 // Evaluates the generator and writes the output to buffer. targetFileName is used only to influence how imports are 58 // generated/optimized. 59 func (p PFlagProvider) generate(generator func(buffer *bytes.Buffer, info TypeInfo) error, buffer *bytes.Buffer, targetFileName string) error { 60 info := TypeInfo{ 61 Name: p.typeName, 62 Fields: p.fields, 63 Package: p.pkg.Name(), 64 Timestamp: time.Now(), 65 Imports: p.Imports(), 66 } 67 68 if err := generator(buffer, info); err != nil { 69 return err 70 } 71 72 // Update imports 73 newBytes, err := goimports.Process(targetFileName, buffer.Bytes(), nil) 74 if err != nil { 75 return err 76 } 77 78 buffer.Reset() 79 _, err = buffer.Write(newBytes) 80 81 return err 82 } 83 84 func newPflagProvider(pkg *types.Package, typeName string, fields []FieldInfo) PFlagProvider { 85 return PFlagProvider{ 86 typeName: typeName, 87 pkg: pkg, 88 fields: fields, 89 } 90 }