github.com/stackb/rules_proto@v0.0.0-20240221195024-5428336c51f1/pkg/protoc/file.go (about)

     1  package protoc
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"os"
     7  	"path"
     8  	"path/filepath"
     9  	"strings"
    10  	"unicode"
    11  
    12  	"github.com/emicklei/proto"
    13  )
    14  
    15  // NewFile takes the package directory and base name of the file (e.g.
    16  // 'foo.proto') and constructs File
    17  func NewFile(dir, basename string) *File {
    18  	return &File{
    19  		Dir:      dir,
    20  		Basename: basename,
    21  		Name:     strings.TrimSuffix(basename, filepath.Ext(basename)),
    22  	}
    23  }
    24  
    25  // File represents a proto file that is discovered in a package.
    26  type File struct {
    27  	Dir      string // e.g. "path/to/package/"
    28  	Basename string // e.g. "foo.proto"
    29  	Name     string // e.g. "foo"
    30  
    31  	pkg         proto.Package
    32  	imports     []proto.Import
    33  	options     []proto.Option
    34  	services    []proto.Service
    35  	rpcs        []proto.RPC
    36  	messages    []proto.Message
    37  	enums       []proto.Enum
    38  	enumOptions []proto.Option
    39  	rpcOptions  []proto.Option
    40  }
    41  
    42  // Relname returns the relative path of the proto file.
    43  func (f *File) Relname() string {
    44  	if f.Dir == "" {
    45  		return f.Basename
    46  	}
    47  	return filepath.Join(f.Dir, f.Basename)
    48  }
    49  
    50  // Package returns the defined package or the empty value.
    51  func (f *File) Package() proto.Package {
    52  	return f.pkg
    53  }
    54  
    55  // Imports returns the list of Imports defined in the proto file.
    56  func (f *File) Imports() []proto.Import {
    57  	return f.imports
    58  }
    59  
    60  // Options returns the list of top-level options defined in the proto file.
    61  func (f *File) Options() []proto.Option {
    62  	return f.options
    63  }
    64  
    65  // Services returns the list of Services defined in the proto file.
    66  func (f *File) Services() []proto.Service {
    67  	return f.services
    68  }
    69  
    70  // Messages returns the list of Messages defined in the proto file.
    71  func (f *File) Messages() []proto.Message {
    72  	return f.messages
    73  }
    74  
    75  // Enums returns the list of Enums defined in the proto file.
    76  func (f *File) Enums() []proto.Enum {
    77  	return f.enums
    78  }
    79  
    80  // EnumOptions returns the list of EnumOptions defined in the proto file.
    81  func (f *File) EnumOptions() []proto.Option {
    82  	return f.enumOptions
    83  }
    84  
    85  // HasEnums returns true if the proto file has at least one enum.
    86  func (f *File) HasEnums() bool {
    87  	return len(f.enums) > 0
    88  }
    89  
    90  // HasMessages returns true if the proto file has at least one message.
    91  func (f *File) HasMessages() bool {
    92  	return len(f.messages) > 0
    93  }
    94  
    95  // HasServices returns true if the proto file has at least one service.
    96  func (f *File) HasServices() bool {
    97  	return len(f.services) > 0
    98  }
    99  
   100  // HasRPCs returns true if the proto file has at least one service.
   101  func (f *File) HasRPCs() bool {
   102  	return len(f.rpcs) > 0
   103  }
   104  
   105  // HasEnumOption returns true if the proto file has at least one enum or enum
   106  // field annotated with the given named field extension.
   107  func (f *File) HasEnumOption(name string) bool {
   108  	for _, option := range f.enumOptions {
   109  		if option.Name == name {
   110  			return true
   111  		}
   112  	}
   113  	return false
   114  }
   115  
   116  // HasRPCOption returns true if the proto file has at least one rpc annotated
   117  // with the given named field extension.
   118  func (f *File) HasRPCOption(name string) bool {
   119  	for _, option := range f.rpcOptions {
   120  		if option.Name == name {
   121  			return true
   122  		}
   123  	}
   124  	return false
   125  }
   126  
   127  // Parse reads the proto file and parses the source.
   128  func (f *File) Parse() error {
   129  	wd, err := os.Getwd()
   130  	if err != nil {
   131  		return fmt.Errorf("could not parse: %v", err)
   132  	}
   133  
   134  	if bwd, ok := os.LookupEnv("BUILD_WORKSPACE_DIRECTORY"); ok {
   135  		wd = bwd
   136  	}
   137  
   138  	filename := filepath.Join(wd, f.Dir, f.Basename)
   139  	reader, err := os.Open(filename)
   140  	if err != nil {
   141  		return fmt.Errorf("could not open %s: %w (cwd=%s)", filename, err, wd)
   142  	}
   143  	defer reader.Close()
   144  
   145  	return f.ParseReader(reader)
   146  }
   147  
   148  // ParseReader parses the reader and walks statements in the file.
   149  func (f *File) ParseReader(in io.Reader) error {
   150  	parser := proto.NewParser(in)
   151  	definition, err := parser.Parse()
   152  	if err != nil {
   153  		return fmt.Errorf("could not parse %s/%s: %w", f.Dir, f.Basename, err)
   154  	}
   155  
   156  	proto.Walk(definition,
   157  		proto.WithPackage(f.handlePackage),
   158  		proto.WithOption(f.handleOption),
   159  		proto.WithImport(f.handleImport),
   160  		proto.WithService(f.handleService),
   161  		proto.WithRPC(f.handleRPC),
   162  		proto.WithMessage(f.handleMessage),
   163  		proto.WithEnum(f.handleEnum))
   164  
   165  	return nil
   166  }
   167  
   168  func (f *File) handlePackage(p *proto.Package) {
   169  	f.pkg = *p
   170  }
   171  
   172  type optionParentVisitor struct {
   173  	proto.NoopVisitor
   174  	visitedRPC       bool
   175  	visitedEnum      bool
   176  	visitedEnumField bool
   177  }
   178  
   179  func (v *optionParentVisitor) VisitRPC(r *proto.RPC) {
   180  	v.visitedRPC = true
   181  }
   182  
   183  func (v *optionParentVisitor) VisitEnum(e *proto.Enum) {
   184  	v.visitedEnum = true
   185  }
   186  
   187  func (v *optionParentVisitor) VisitEnumField(f *proto.EnumField) {
   188  	v.visitedEnumField = true
   189  }
   190  
   191  func (f *File) handleOption(o *proto.Option) {
   192  	f.options = append(f.options, *o)
   193  	var parentVisitor optionParentVisitor
   194  	o.Parent.Accept(&parentVisitor)
   195  	if parentVisitor.visitedEnum || parentVisitor.visitedEnumField {
   196  		f.enumOptions = append(f.enumOptions, *o)
   197  	}
   198  	if parentVisitor.visitedRPC {
   199  		f.rpcOptions = append(f.rpcOptions, *o)
   200  	}
   201  }
   202  
   203  func (f *File) handleImport(i *proto.Import) {
   204  	f.imports = append(f.imports, *i)
   205  }
   206  
   207  func (f *File) handleEnum(i *proto.Enum) {
   208  	f.enums = append(f.enums, *i)
   209  }
   210  
   211  func (f *File) handleService(s *proto.Service) {
   212  	f.services = append(f.services, *s)
   213  }
   214  
   215  func (f *File) handleRPC(r *proto.RPC) {
   216  	f.rpcs = append(f.rpcs, *r)
   217  }
   218  
   219  func (f *File) handleMessage(m *proto.Message) {
   220  	f.messages = append(f.messages, *m)
   221  }
   222  
   223  // PackageFileNameWithExtensions returns a function that computes the name of a
   224  // predicted generated file having the given extension(s).  If the proto package
   225  // is defined, the output file will be in the corresponding directory.
   226  func PackageFileNameWithExtensions(exts ...string) func(f *File) []string {
   227  	return func(f *File) []string {
   228  		outs := make([]string, len(exts))
   229  		name := f.Name
   230  		pkg := f.Package()
   231  		if pkg.Name != "" {
   232  			name = path.Join(strings.ReplaceAll(pkg.Name, ".", "/"), name)
   233  		}
   234  		for i, ext := range exts {
   235  			outs[i] = name + ext
   236  		}
   237  		return outs
   238  	}
   239  }
   240  
   241  // RelativeFileNameWithExtensions returns a function that computes the name of a
   242  // predicted generated file having the given extension(s) relative to the given
   243  // dir.
   244  func RelativeFileNameWithExtensions(reldir string, exts ...string) func(f *File) []string {
   245  	return func(f *File) []string {
   246  		outs := make([]string, len(exts))
   247  		name := f.Name
   248  		if reldir != "" {
   249  			name = path.Join(reldir, name)
   250  		}
   251  		for i, ext := range exts {
   252  			outs[i] = name + ext
   253  		}
   254  		return outs
   255  	}
   256  }
   257  
   258  // ImportPrefixRelativeFileNameWithExtensions returns a function that computes
   259  // the name of a predicted generated file. In this case, first
   260  // RelativeFileNameWithExtensions is applied, then stripImportPrefix is removed
   261  // from the predicted filename.
   262  func ImportPrefixRelativeFileNameWithExtensions(stripImportPrefix, reldir string, exts ...string) func(f *File) []string {
   263  	// if the stripImportPrefix is defined and "absolute" (starting with a
   264  	// slash), this means it is relative to the repository root.
   265  	// https://github.com/bazelbuild/bazel/issues/3867#issuecomment-441971525
   266  	prefix := strings.TrimPrefix(stripImportPrefix, "/")
   267  	relfunc := RelativeFileNameWithExtensions(reldir, exts...)
   268  	return func(f *File) []string {
   269  		outs := relfunc(f)
   270  		for i, out := range outs {
   271  			if strings.HasPrefix(out, prefix) {
   272  				outs[i] = strings.TrimPrefix(out[len(prefix):], "/")
   273  			}
   274  		}
   275  		return outs
   276  	}
   277  }
   278  
   279  // HasMessagesOrEnums checks if any of the given files has a message or an enum.
   280  func HasMessagesOrEnums(files ...*File) bool {
   281  	for _, f := range files {
   282  		if HasMessageOrEnum(f) {
   283  			return true
   284  		}
   285  	}
   286  	return false
   287  }
   288  
   289  // HasServices checks if any of the given files has a service.
   290  func HasServices(files ...*File) bool {
   291  	for _, f := range files {
   292  		if HasService(f) {
   293  			return true
   294  		}
   295  	}
   296  	return false
   297  }
   298  
   299  // HasMessageOrEnum is a file predicate function checks if any of the given file
   300  // has a message or an enum.
   301  func HasMessageOrEnum(file *File) bool {
   302  	return file.HasMessages() || file.HasEnums()
   303  }
   304  
   305  // Always is a file predicate function that always returns true.
   306  func Always(file *File) bool {
   307  	return true
   308  }
   309  
   310  // HasService is a file predicate function that tests if any of the given file
   311  // has a message or an enum.
   312  func HasService(file *File) bool {
   313  	return file.HasServices()
   314  }
   315  
   316  // FlatMapFiles is a utility function intended for use in computing a list of
   317  // output files for a given proto_library. The given apply function is executed
   318  // foreach file that passes the filter function, and flattens the strings into a
   319  // single list.
   320  func FlatMapFiles(apply func(file *File) []string, filter func(file *File) bool, files ...*File) []string {
   321  	values := make([]string, 0)
   322  	for _, f := range files {
   323  		if !filter(f) {
   324  			continue
   325  		}
   326  		values = append(values, apply(f)...)
   327  	}
   328  	return values
   329  }
   330  
   331  // GoPackagePath replaces dots with forward slashes.
   332  func GoPackagePath(pkg string) string {
   333  	return strings.ReplaceAll(pkg, ".", "/")
   334  }
   335  
   336  // IsProtoFile returns true if the file extension looks like it should contain
   337  // protobuf definitions.
   338  func IsProtoFile(filename string) bool {
   339  	ext := filepath.Ext(filename)
   340  	return ext == ".proto" || ext == ".protodevel"
   341  }
   342  
   343  // GoPackageOption is a utility function to seek for the go_package option and
   344  // split it.  If present the return values will be populated with the importpath
   345  // and alias (e.g. github.com/foo/bar/v1;bar -> "github.com/foo/bar/v1", "bar").
   346  // If the option was not found the bool return argument is false.
   347  func GoPackageOption(options []proto.Option) (string, string, bool) {
   348  	for _, opt := range options {
   349  		if opt.Name != "go_package" {
   350  			continue
   351  		}
   352  		parts := strings.SplitN(opt.Constant.Source, ";", 2)
   353  		switch len(parts) {
   354  		case 0:
   355  			return "", "", true
   356  		case 1:
   357  			return parts[0], "", true
   358  		case 2:
   359  			return parts[0], parts[1], true
   360  		default:
   361  			return parts[0], strings.Join(parts[1:], ";"), true
   362  		}
   363  	}
   364  
   365  	return "", "", false
   366  }
   367  
   368  // GetNamedOption returns the value of an option.  If the option is not found,
   369  // the bool return value is false.
   370  func GetNamedOption(options []proto.Option, name string) (string, bool) {
   371  	for _, opt := range options {
   372  		if opt.Name != name {
   373  			continue
   374  		}
   375  		return opt.Constant.Source, true
   376  	}
   377  	return "", false
   378  }
   379  
   380  // ToPascalCase converts a string to PascalCase.
   381  //
   382  // Splits on '-', '_', ' ', '\t', '\n', '\r'.
   383  // Uppercase letters will stay uppercase,
   384  func ToPascalCase(s string) string {
   385  	output := ""
   386  	var previous rune
   387  	for i, c := range strings.TrimSpace(s) {
   388  		if !isDelimiter(c) {
   389  			if i == 0 || isDelimiter(previous) || unicode.IsUpper(c) {
   390  				output += string(unicode.ToUpper(c))
   391  			} else {
   392  				output += string(unicode.ToLower(c))
   393  			}
   394  		}
   395  		previous = c
   396  	}
   397  	return output
   398  }
   399  
   400  func isDelimiter(r rune) bool {
   401  	return r == '.' || r == '-' || r == '_' || r == ' ' || r == '\t' || r == '\n' || r == '\r'
   402  }