github.com/lyft/flytestdlib@v0.3.12-0.20210213045714-8cdd111ecda1/cli/pflags/api/pflag_provider.go (about)

     1  package api
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/types"
     7  	"io/ioutil"
     8  	"os"
     9  	"time"
    10  
    11  	"github.com/ernesto-jimenez/gogen/imports"
    12  	goimports "golang.org/x/tools/imports"
    13  )
    14  
// PFlagProvider carries the inputs needed to render the pflags code and test templates for a single named type:
//   - typeName: the name of the type, passed to the templates as TypeInfo.Name;
//   - pkg: the package whose Name() is passed to the templates as TypeInfo.Package and is used to seed the
//     import collector in Imports;
//   - fields: passed to the templates as TypeInfo.Fields; the type of each field is scanned by Imports.
//
// newPflagProvider constructs values of this type.
type PFlagProvider struct {
	typeName string
	pkg      *types.Package
	fields   []FieldInfo
}
    20  
    21  // Adds any needed imports for types not directly declared in this package.
    22  func (p PFlagProvider) Imports() map[string]string {
    23  	imp := imports.New(p.pkg.Name())
    24  	for _, m := range p.fields {
    25  		imp.AddImportsFrom(m.Typ)
    26  	}
    27  
    28  	return imp.Imports()
    29  }
    30  
    31  // Evaluates the main code file template and writes the output to outputFilePath
    32  func (p PFlagProvider) WriteCodeFile(outputFilePath string) error {
    33  	buf := bytes.Buffer{}
    34  	err := p.generate(GenerateCodeFile, &buf, outputFilePath)
    35  	if err != nil {
    36  		return fmt.Errorf("error generating code, Error: %v. Source: %v", err, buf.String())
    37  	}
    38  
    39  	return p.writeToFile(&buf, outputFilePath)
    40  }
    41  
    42  // Evaluates the test code file template and writes the output to outputFilePath
    43  func (p PFlagProvider) WriteTestFile(outputFilePath string) error {
    44  	buf := bytes.Buffer{}
    45  	err := p.generate(GenerateTestFile, &buf, outputFilePath)
    46  	if err != nil {
    47  		return fmt.Errorf("error generating code, Error: %v. Source: %v", err, buf.String())
    48  	}
    49  
    50  	return p.writeToFile(&buf, outputFilePath)
    51  }
    52  
    53  func (p PFlagProvider) writeToFile(buffer *bytes.Buffer, fileName string) error {
    54  	return ioutil.WriteFile(fileName, buffer.Bytes(), os.ModePerm)
    55  }
    56  
    57  // Evaluates the generator and writes the output to buffer. targetFileName is used only to influence how imports are
    58  // generated/optimized.
    59  func (p PFlagProvider) generate(generator func(buffer *bytes.Buffer, info TypeInfo) error, buffer *bytes.Buffer, targetFileName string) error {
    60  	info := TypeInfo{
    61  		Name:      p.typeName,
    62  		Fields:    p.fields,
    63  		Package:   p.pkg.Name(),
    64  		Timestamp: time.Now(),
    65  		Imports:   p.Imports(),
    66  	}
    67  
    68  	if err := generator(buffer, info); err != nil {
    69  		return err
    70  	}
    71  
    72  	// Update imports
    73  	newBytes, err := goimports.Process(targetFileName, buffer.Bytes(), nil)
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	buffer.Reset()
    79  	_, err = buffer.Write(newBytes)
    80  
    81  	return err
    82  }
    83  
    84  func newPflagProvider(pkg *types.Package, typeName string, fields []FieldInfo) PFlagProvider {
    85  	return PFlagProvider{
    86  		typeName: typeName,
    87  		pkg:      pkg,
    88  		fields:   fields,
    89  	}
    90  }