github.com/oNaiPs/go-generate-fast@v0.3.0/src/plugins/protoc/protoc.go (about)

     1  package plugin_protoc
     2  
     3  import (
     4  	"bufio"
     5  	"os"
     6  	"path"
     7  	"path/filepath"
     8  	"regexp"
     9  	"strings"
    10  
    11  	"github.com/jessevdk/go-flags"
    12  	"github.com/oNaiPs/go-generate-fast/src/plugins"
    13  	"go.uber.org/zap"
    14  )
    15  
    16  type ProtocPlugin struct {
    17  	plugins.Plugin
    18  }
    19  
    20  func (p *ProtocPlugin) Name() string {
    21  	return "protoc"
    22  }
    23  
    24  func (p *ProtocPlugin) Matches(opts plugins.GenerateOpts) bool {
    25  	return opts.ExecutableName == "protoc"
    26  }
    27  
    28  type ProtocParsedFlags struct {
    29  	Include []string `short:"I" long:"proto_path"`
    30  	GoOut   string   `long:"go_out"`
    31  	GoOpt   []string `long:"go_opt"`
    32  }
    33  
    34  func (p *ProtocPlugin) ComputeInputOutputFiles(opts plugins.GenerateOpts) *plugins.InputOutputFiles {
    35  	parsedFlags := ProtocParsedFlags{}
    36  	args, err := flags.ParseArgs(&parsedFlags, opts.SanitizedArgs)
    37  	if len(parsedFlags.Include) == 0 {
    38  		// default search path when no include paths are specified is current dir
    39  		parsedFlags.Include = append(parsedFlags.Include, opts.Dir())
    40  	}
    41  
    42  	if err != nil {
    43  		panic(err)
    44  	}
    45  
    46  	ioFiles := plugins.InputOutputFiles{}
    47  
    48  	pathsMode := "import"
    49  	for _, opt := range parsedFlags.GoOpt {
    50  		if strings.HasPrefix(opt, "paths") {
    51  			pathsMode = strings.TrimPrefix(opt, "paths=")
    52  		}
    53  	}
    54  
    55  	for index, inputFile := range args {
    56  		if !strings.HasSuffix(inputFile, ".proto") {
    57  			zap.S().Warn("Input file doesn't end with .proto: ", inputFile, ". Skipping")
    58  			continue
    59  		}
    60  
    61  		var inputFilePath string
    62  		if !filepath.IsAbs(inputFile) {
    63  			inputFilePath = searchFile(inputFile, parsedFlags.Include, opts.Dir())
    64  			if inputFilePath == "" {
    65  				zap.S().Warn("Cannot find ", inputFile, " in include dirs. Skipping")
    66  				continue
    67  			}
    68  			inputFile = filepath.Base(inputFilePath)
    69  		} else {
    70  			inputFilePath = inputFile
    71  			inputFile = filepath.Base(inputFile)
    72  		}
    73  		inputFileDir := filepath.Dir(inputFilePath)
    74  
    75  		ioFiles.InputFiles = append(ioFiles.InputFiles, inputFilePath)
    76  
    77  		protoFile, err := parseProtoFile(inputFilePath)
    78  		if err != nil {
    79  			zap.S().Warn("Cannot parse proto file: ", inputFilePath, ". Skipping")
    80  			continue
    81  		}
    82  		outputDir := ""
    83  		switch pathsMode {
    84  		// uses go_package specified in the .proto file
    85  		case "import":
    86  			goPackage := protoFile.GoPackage
    87  
    88  			//parse M go_opt (see https://protobuf.dev/reference/go/go-generated/#package)
    89  			for _, opt := range parsedFlags.GoOpt {
    90  				if !strings.HasPrefix(opt, "M") {
    91  					continue
    92  				}
    93  
    94  				spl := strings.Split(strings.TrimPrefix(opt, "M"), "=")
    95  
    96  				if spl[0] == args[index] {
    97  					goPackage = spl[1]
    98  				}
    99  			}
   100  
   101  			//remove possible package_name specification (not supported atm)
   102  			goPackage = strings.Split(goPackage, ";")[0]
   103  
   104  			outputDir = filepath.Join(opts.Dir(), goPackage)
   105  
   106  		case "source_relative":
   107  			outputDir = inputFileDir
   108  		default:
   109  			zap.S().Fatal("Unknown paths mode ", pathsMode)
   110  		}
   111  
   112  		outputFile := path.Join(outputDir, strings.TrimSuffix(inputFile, ".proto")+".pb.go")
   113  		ioFiles.OutputFiles = append(ioFiles.OutputFiles, outputFile)
   114  
   115  		for _, importFile := range protoFile.Imports {
   116  			importFilePath := searchFile(importFile, parsedFlags.Include, inputFileDir)
   117  			if importFilePath != "" {
   118  				ioFiles.InputFiles = append(ioFiles.InputFiles, importFilePath)
   119  			} else {
   120  				zap.S().Warn("Cannot find import ", importFile, " in include dirs. Skipping")
   121  			}
   122  		}
   123  	}
   124  
   125  	// TODO add @<filename>
   126  
   127  	return &ioFiles
   128  }
   129  
   130  // searchFile searches for a given file within a list of include directories.
   131  // It returns the full path of the file if found in any of the directories.
   132  // If the file is not found, an empty string is returned.
   133  func searchFile(filePath string, includeDirs []string, baseDir string) string {
   134  	for _, includeDir := range includeDirs {
   135  		if !filepath.IsAbs(includeDir) {
   136  			includeDir = filepath.Join(baseDir, includeDir)
   137  		}
   138  		fullPath := filepath.Join(includeDir, filePath)
   139  		if _, err := os.Stat(fullPath); err == nil {
   140  			return fullPath
   141  		}
   142  	}
   143  	return ""
   144  }
   145  
   146  type ProtoFile struct {
   147  	Imports   []string
   148  	GoPackage string
   149  }
   150  
   151  func parseProtoFile(filePath string) (*ProtoFile, error) {
   152  	file, err := os.Open(filePath)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	defer file.Close()
   157  
   158  	goFile := &ProtoFile{}
   159  	scanner := bufio.NewScanner(file)
   160  
   161  	for scanner.Scan() {
   162  		line := scanner.Text()
   163  
   164  		importMatch := regexp.MustCompile(`^\s*import\s+"(.*)"`).FindStringSubmatch(line)
   165  		if len(importMatch) > 0 {
   166  			goFile.Imports = append(goFile.Imports, importMatch[1])
   167  		}
   168  
   169  		goPackageMatch := regexp.MustCompile(`^\s*option\s+go_package\s*=\s*"(.*)"`).FindStringSubmatch(line)
   170  		if len(goPackageMatch) > 0 {
   171  			goFile.GoPackage = goPackageMatch[1]
   172  		}
   173  	}
   174  
   175  	if err := scanner.Err(); err != nil {
   176  		return nil, err
   177  	}
   178  
   179  	return goFile, nil
   180  }
   181  
   182  func init() {
   183  	plugins.RegisterPlugin(&ProtocPlugin{})
   184  }