github.com/stackb/rules_proto@v0.0.0-20240221195024-5428336c51f1/pkg/rule/rules_scala/scala_library.go (about)

     1  package rules_scala
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"log"
     7  	"strings"
     8  
     9  	"github.com/bazelbuild/bazel-gazelle/config"
    10  	"github.com/bazelbuild/bazel-gazelle/label"
    11  	"github.com/bazelbuild/bazel-gazelle/resolve"
    12  	"github.com/bazelbuild/bazel-gazelle/rule"
    13  	"github.com/bmatcuk/doublestar"
    14  	"github.com/emicklei/proto"
    15  
    16  	"github.com/stackb/rules_proto/pkg/plugin/akka/akka_grpc"
    17  	"github.com/stackb/rules_proto/pkg/plugin/scalapb/scalapb"
    18  	"github.com/stackb/rules_proto/pkg/protoc"
    19  )
    20  
    21  const (
    22  	GrpcscalaLibraryRuleName        = "grpc_scala_library"
    23  	ProtoscalaLibraryRuleName       = "proto_scala_library"
    24  	protoScalaLibraryRuleSuffix     = "_proto_scala_library"
    25  	grpcScalaLibraryRuleSuffix      = "_grpc_scala_library"
    26  	scalaPbPluginOptionsPrivateKey  = "_scalapb_plugin"
    27  	akkaGrpcPluginOptionsPrivateKey = "_akka_grpc_plugin"
    28  	scalapbOptionsName              = "(scalapb.options)"
    29  	scalapbFieldTypeName            = "(scalapb.field).type"
    30  	scalaLangName                   = "scala"
    31  )
    32  
    33  func init() {
    34  	protoc.Rules().MustRegisterRule("stackb:rules_proto:"+ProtoscalaLibraryRuleName,
    35  		&scalaLibrary{
    36  			kindName:        ProtoscalaLibraryRuleName,
    37  			ruleSuffix:      protoScalaLibraryRuleSuffix,
    38  			protoFileFilter: messageFiles,
    39  		})
    40  	protoc.Rules().MustRegisterRule("stackb:rules_proto:"+GrpcscalaLibraryRuleName,
    41  		&scalaLibrary{
    42  			kindName:        GrpcscalaLibraryRuleName,
    43  			ruleSuffix:      grpcScalaLibraryRuleSuffix,
    44  			protoFileFilter: serviceFiles,
    45  		})
    46  }
    47  
// scalaLibrary implements LanguageRule for the 'proto_scala_library' rule from
// @rules_proto.  The same type also backs 'grpc_scala_library'; the two
// registrations differ only in the values of the fields below.
type scalaLibrary struct {
	// kindName is the kind of the generated rule and the value returned by
	// Name.
	kindName        string
	// ruleSuffix is appended to the proto_library base name to form the
	// generated rule's name.
	ruleSuffix      string
	// protoFileFilter selects which proto_library files this rule handles.
	// If it selects none, ProvideRule returns nil.
	protoFileFilter func([]*protoc.File) []*protoc.File
}
    55  
    56  // Name implements part of the LanguageRule interface.
    57  func (s *scalaLibrary) Name() string {
    58  	return s.kindName
    59  }
    60  
    61  // KindInfo implements part of the LanguageRule interface.
    62  func (s *scalaLibrary) KindInfo() rule.KindInfo {
    63  	return rule.KindInfo{
    64  		MergeableAttrs: map[string]bool{
    65  			"srcs":    true,
    66  			"exports": true,
    67  		},
    68  		NonEmptyAttrs: map[string]bool{
    69  			"srcs": true,
    70  		},
    71  		ResolveAttrs: map[string]bool{
    72  			"deps": true,
    73  		},
    74  	}
    75  }
    76  
    77  // LoadInfo implements part of the LanguageRule interface.
    78  func (s *scalaLibrary) LoadInfo() rule.LoadInfo {
    79  	return rule.LoadInfo{
    80  		Name:    fmt.Sprintf("@build_stack_rules_proto//rules/scala:%s.bzl", s.kindName),
    81  		Symbols: []string{s.kindName},
    82  	}
    83  }
    84  
    85  // ProvideRule implements part of the LanguageRule interface.
    86  func (s *scalaLibrary) ProvideRule(cfg *protoc.LanguageRuleConfig, pc *protoc.ProtocConfiguration) protoc.RuleProvider {
    87  	files := s.protoFileFilter(pc.Library.Files())
    88  	if len(files) == 0 {
    89  		return nil
    90  	}
    91  
    92  	options := parseScalaLibraryOptions(s.kindName, cfg.GetOptions())
    93  
    94  	// the list of output files
    95  	outputs := make([]string, 0)
    96  
    97  	if len(options.plugins) == 0 {
    98  		log.Printf("warning: the rule %s should have at least one plugin name for the --plugins option.  This informs the rule which plugin(s) outputs correspond to this library rule", s.Name())
    99  	}
   100  
   101  	for _, name := range options.plugins {
   102  		plugin := getPluginConfiguration(pc.Plugins, name)
   103  		if plugin == nil {
   104  			// TODO: warn here?
   105  			continue
   106  		}
   107  		outputs = append(outputs, plugin.Outputs...)
   108  	}
   109  
   110  	outputs = options.filterOutputs(outputs)
   111  
   112  	if len(outputs) == 0 {
   113  		return nil
   114  	}
   115  
   116  	return &scalaLibraryRule{
   117  		kindName:       s.kindName,
   118  		ruleNameSuffix: s.ruleSuffix,
   119  		options:        options,
   120  		outputs:        outputs,
   121  		ruleConfig:     cfg,
   122  		config:         pc,
   123  		files:          files,
   124  	}
   125  }
   126  
// scalaLibraryRule implements RuleProvider for 'scala_library'-derived rules.
// scalaLibrary.ProvideRule creates one for each proto_library that yields
// at least one output for the rule.
type scalaLibraryRule struct {
	// kindName is the kind of the generated rule.
	kindName       string
	// ruleNameSuffix is appended to the proto_library base name (see Name).
	ruleNameSuffix string
	// outputs are the filtered plugin outputs; ".srcjar" entries become srcs.
	outputs        []string
	// config is the protoc configuration of the proto_library.
	config         *protoc.ProtocConfiguration
	// ruleConfig supplies configured deps, attrs, options and visibility.
	ruleConfig     *protoc.LanguageRuleConfig
	// options are the parsed flags of the language rule.
	options        *scalaLibraryOptions
	// files are the proto files accepted by the originating file filter.
	files          []*protoc.File
}
   137  
   138  // Kind implements part of the ruleProvider interface.
   139  func (s *scalaLibraryRule) Kind() string {
   140  	return s.kindName
   141  }
   142  
   143  // Name implements part of the ruleProvider interface.
   144  func (s *scalaLibraryRule) Name() string {
   145  	return s.config.Library.BaseName() + s.ruleNameSuffix
   146  }
   147  
   148  // Srcs computes the srcs list for the rule.
   149  func (s *scalaLibraryRule) Srcs() []string {
   150  	srcs := make([]string, 0)
   151  	for _, output := range s.outputs {
   152  		if strings.HasSuffix(output, ".srcjar") {
   153  			srcs = append(srcs, protoc.StripRel(s.config.Rel, output))
   154  		}
   155  	}
   156  	return srcs
   157  }
   158  
   159  // Deps computes the deps list for the rule.
   160  func (s *scalaLibraryRule) Deps() []string {
   161  	deps := s.ruleConfig.GetDeps()
   162  
   163  	for _, pluginConfig := range s.config.Plugins {
   164  		deps = append(deps, pluginConfig.Config.GetDeps()...)
   165  	}
   166  
   167  	return protoc.DeduplicateAndSort(deps)
   168  }
   169  
   170  // Visibility provides visibility labels.
   171  func (s *scalaLibraryRule) Visibility() []string {
   172  	return s.ruleConfig.GetVisibility()
   173  }
   174  
// Rule implements part of the ruleProvider interface.  It builds the
// generated rule as follows:
//
//   - "srcs" is always set.
//   - "deps", "exports" and "visibility" are set only when non-empty.
//   - Private attributes are stored for later phases: scalapb imports, the
//     resolver language override, and plugin options read back by Imports.
//
// The otherGen argument is not used.
func (s *scalaLibraryRule) Rule(otherGen ...*rule.Rule) *rule.Rule {
	newRule := rule.NewRule(s.Kind(), s.Name())

	newRule.SetAttr("srcs", s.Srcs())

	deps := s.Deps()
	if len(deps) > 0 {
		newRule.SetAttr("deps", deps)
	}

	// exports are taken verbatim from the language rule's "exports" attr.
	exports := s.ruleConfig.GetAttr("exports")
	if len(exports) > 0 {
		newRule.SetAttr("exports", exports)
	}

	visibility := s.Visibility()
	if len(visibility) > 0 {
		newRule.SetAttr("visibility", visibility)
	}

	// add any imports from proto options.  Example: option (scalapb.options) =
	// {
	//  import: "com.foo.Bar"
	// };
	// They are stored under config.GazelleImportsKey, presumably so that gazelle
	// hands them to Resolve as imports (verify against gazelle's resolver).
	//
	// NOTE: we pass *all* files from the proto_library.  Although the
	// fileFilter has reduced the set into grpc or non-grpc ones, in practice
	// protoc-gen-scala only has the "grpc" option.  When it is OFF, it will
	// produce a srcjar with only messages.  When it is ON, the compiler will
	// produce a srcjar with both messages and services.  There is no way to
	// tell the compiler to generate ONLY services (and not messages).
	// Therefore, we need all dependencies in order to compile the messages.
	scalaImports := getScalapbImports(s.config.Library.Files())
	if len(scalaImports) > 0 {
		newRule.SetPrivateAttr(config.GazelleImportsKey, scalaImports)
	}

	// set the override language such that deps of 'proto_scala_library' and
	// 'grpc_scala_library' can resolve together (matches the "scala" kind
	// used for the import specs returned by Imports).
	newRule.SetPrivateAttr(protoc.ResolverImpLangPrivateKey, "scala")

	// add the scalapb and akka-grpc plugin options as private attrs so they
	// can be inspected during the Imports() phase.  For example, akka
	// 'server_power_apis' generates additional classes.  Plugins absent from
	// the configuration store nothing.
	scalaPbPlugin := s.config.GetPluginConfiguration(scalapb.ScalaPBPluginName)
	if scalaPbPlugin != nil {
		newRule.SetPrivateAttr(scalaPbPluginOptionsPrivateKey, scalaPbPlugin.Options)
	}
	akkaGrpcPlugin := s.config.GetPluginConfiguration(akka_grpc.AkkaGrpcPluginName)
	if akkaGrpcPlugin != nil {
		newRule.SetPrivateAttr(akkaGrpcPluginOptionsPrivateKey, akkaGrpcPlugin.Options)
	}

	return newRule
}
   232  
   233  // Imports implements part of the RuleProvider interface.
   234  func (s *scalaLibraryRule) Imports(c *config.Config, r *rule.Rule, file *rule.File) []resolve.ImportSpec {
   235  	// 1. provide generated scala class names for message and services for
   236  	// 'scala scala' deps.  This will allow a scala extension to resolve proto
   237  	// deps when they import scala proto class names.
   238  	pluginOptions := make(map[string]bool)
   239  	if scalaPbPluginOptions, ok := r.PrivateAttr(scalaPbPluginOptionsPrivateKey).([]string); ok {
   240  		for _, opt := range scalaPbPluginOptions {
   241  			pluginOptions[opt] = true
   242  		}
   243  	}
   244  	if akkaGrpcPluginOptions, ok := r.PrivateAttr(akkaGrpcPluginOptionsPrivateKey).([]string); ok {
   245  		for _, opt := range akkaGrpcPluginOptions {
   246  			pluginOptions[opt] = true
   247  		}
   248  	}
   249  	from := label.New("", file.Pkg, r.Name())
   250  
   251  	provideScalaImports(s.files, protoc.GlobalResolver(), from, pluginOptions)
   252  
   253  	// 2. create import specs for 'protobuf scala'.  This allows
   254  	// proto_scala_library and grpc_scala_library to resolve deps.
   255  	return protoc.ProtoFilesImportSpecsForKind("scala", s.files)
   256  }
   257  
   258  // Resolve implements part of the RuleProvider interface.
   259  func (s *scalaLibraryRule) Resolve(c *config.Config, ix *resolve.RuleIndex, r *rule.Rule, imports []string, from label.Label) {
   260  	imports = s.options.filterImports(imports)
   261  
   262  	resolveFn := protoc.ResolveDepsAttr("deps", !s.options.resolveWKTs)
   263  	resolveFn(c, ix, r, imports, from)
   264  
   265  	if unresolvedDeps, ok := r.PrivateAttr(protoc.UnresolvedDepsPrivateKey).(map[string]error); ok {
   266  		if from.Repo == c.RepoName {
   267  			from.Repo = ""
   268  		}
   269  		resolveScalaDeps(resolve.FindRuleWithOverride, ix.FindRulesByImportWithConfig, c, r, unresolvedDeps, from)
   270  
   271  		for imp, err := range unresolvedDeps {
   272  			if err == nil {
   273  				continue
   274  			}
   275  			log.Printf("%[1]v (%[2]s): warning: failed to resolve %[3]q: %v", from, r.Kind(), imp, err)
   276  		}
   277  	}
   278  }
   279  
   280  // findRuleWithOverride is the same shape of resolve.FindRuleWithOverride.
   281  type findRuleWithOverride func(c *config.Config, imp resolve.ImportSpec, lang string) (label.Label, bool)
   282  
   283  // findRulesByImportWithConfig is the same shape of resolve.RuleIndex.FindRulesByImportWithConfig.
   284  // For testability want to avoid the RuleIndex as it is fundamentally tied to the resolve.resolveConfig,
   285  // which is private and not easily mocked.
   286  type findRulesByImportWithConfig func(c *config.Config, imp resolve.ImportSpec, lang string) []resolve.FindResult
   287  
   288  // resolveScalaDeps attempts to resolve labels for the given deps under the
   289  // "scala" language.  Only unresolved deps of type ErrNoLabel are considered.
   290  // Typically these unresolved dependencies arise from (scalapb.options) imports.
   291  func resolveScalaDeps(
   292  	findRuleWithOverride findRuleWithOverride,
   293  	findRulesByImportWithConfig findRulesByImportWithConfig,
   294  	c *config.Config,
   295  	r *rule.Rule,
   296  	unresolvedDeps map[string]error,
   297  	from label.Label,
   298  ) {
   299  
   300  	resolvedDeps := make([]string, 0)
   301  
   302  	markResolved := func(imp string, to label.Label) {
   303  		delete(unresolvedDeps, imp)
   304  		if to == from {
   305  			return
   306  		}
   307  		resolvedDeps = append(resolvedDeps, to.String())
   308  	}
   309  
   310  	for imp, err := range unresolvedDeps {
   311  		if err != protoc.ErrNoLabel {
   312  			continue
   313  		}
   314  		importSpec := resolve.ImportSpec{Lang: scalaLangName, Imp: imp}
   315  		if l, ok := findRuleWithOverride(c, importSpec, scalaLangName); ok {
   316  			markResolved(imp, l)
   317  			continue
   318  		}
   319  		result := findRulesByImportWithConfig(c, importSpec, scalaLangName)
   320  		if len(result) == 0 {
   321  			continue
   322  		}
   323  		if len(result) > 1 {
   324  			log.Println(from, "multiple rules matched for scala import %q: %v", imp, result)
   325  			continue
   326  		}
   327  		markResolved(imp, result[0].Label)
   328  	}
   329  	if len(resolvedDeps) > 0 {
   330  		r.SetAttr("deps", protoc.DeduplicateAndSort(append(r.AttrStrings("deps"), resolvedDeps...)))
   331  	}
   332  }
   333  
   334  func getScalapbImports(files []*protoc.File) []string {
   335  	imps := make([]string, 0)
   336  
   337  	for _, file := range files {
   338  		for _, option := range file.Options() {
   339  			if option.Name != scalapbOptionsName {
   340  				continue
   341  			}
   342  			for _, namedLiteral := range option.Constant.OrderedMap {
   343  				switch namedLiteral.Name {
   344  				case "import":
   345  					if namedLiteral.Source != "" {
   346  						imps = append(imps, parseScalaImportNamedLiteral(namedLiteral.Source)...)
   347  					}
   348  				}
   349  			}
   350  		}
   351  		for _, msg := range file.Messages() {
   352  			for _, child := range msg.Elements {
   353  				if field, ok := child.(*proto.NormalField); ok {
   354  					for _, option := range field.Options {
   355  						if option.Name != scalapbFieldTypeName {
   356  							continue
   357  						}
   358  						if option.Constant.Source != "" {
   359  							imps = append(imps, option.Constant.Source)
   360  						}
   361  					}
   362  				}
   363  			}
   364  		}
   365  	}
   366  
   367  	return protoc.DeduplicateAndSort(imps)
   368  }
   369  
   370  func parseScalaImportNamedLiteral(lit string) (imports []string) {
   371  	ob := strings.Index(lit, "{")
   372  	cb := strings.Index(lit, "}")
   373  	if ob == -1 || cb == -1 {
   374  		return []string{lit}
   375  	}
   376  	prefix := strings.TrimRight(lit[:ob], ".")
   377  	exprs := strings.Split(lit[ob+1:cb], ",")
   378  	for _, expr := range exprs {
   379  		expr = strings.TrimSpace(expr)
   380  		parts := strings.Split(expr, "=>")
   381  		if len(parts) == 2 {
   382  			source := strings.TrimSpace(parts[0])
   383  			imports = append(imports, prefix+"."+source)
   384  		} else {
   385  			imports = append(imports, prefix+"."+expr)
   386  
   387  		}
   388  	}
   389  	return
   390  }
   391  
   392  // javaPackageOption is a utility function to seek for the java_package option.
   393  func javaPackageOption(options []proto.Option) (string, bool) {
   394  	for _, opt := range options {
   395  		if opt.Name != "java_package" {
   396  			continue
   397  		}
   398  		return opt.Constant.Source, true
   399  	}
   400  
   401  	return "", false
   402  }
   403  
   404  func provideScalaImports(files []*protoc.File, resolver protoc.ImportResolver, from label.Label, options map[string]bool) {
   405  	lang := "scala"
   406  
   407  	for _, file := range files {
   408  		pkgName := file.Package().Name
   409  		if javaPackageName, ok := javaPackageOption(file.Options()); ok {
   410  			pkgName = javaPackageName
   411  		}
   412  		if pkgName != "" {
   413  			resolver.Provide(lang, "package", pkgName, from)
   414  		}
   415  		for _, e := range file.Enums() {
   416  			name := e.Name
   417  			if pkgName != "" {
   418  				name = pkgName + "." + name
   419  			}
   420  			resolver.Provide(lang, "enum", name, from)
   421  			for _, value := range e.Elements {
   422  				if field, ok := value.(*proto.EnumField); ok {
   423  					fieldName := name + "." + field.Name
   424  					resolver.Provide(lang, lang, fieldName, from)
   425  				}
   426  			}
   427  		}
   428  		for _, m := range file.Messages() {
   429  			name := m.Name
   430  			if pkgName != "" {
   431  				name = pkgName + "." + name
   432  			}
   433  			resolver.Provide(lang, "message", name, from)
   434  			resolver.Provide(lang, "message", name+"Proto", from)
   435  		}
   436  		for _, s := range file.Services() {
   437  			name := s.Name
   438  			if pkgName != "" {
   439  				name = pkgName + "." + name
   440  			}
   441  			resolver.Provide(lang, "service", name, from)
   442  			resolver.Provide(lang, "service", name+"Grpc", from)
   443  			resolver.Provide(lang, "service", name+"Proto", from)
   444  			resolver.Provide(lang, "service", name+"Client", from)
   445  			resolver.Provide(lang, "service", name+"Handler", from)
   446  			resolver.Provide(lang, "service", name+"Server", from)
   447  			// TOOD: if this is configured on the proto_plugin, we won't know
   448  			// about the plugin option.  Advertise them anyway.
   449  			// if options["server_power_apis"] {
   450  			resolver.Provide(lang, "service", name+"PowerApi", from)
   451  			resolver.Provide(lang, "service", name+"PowerApiHandler", from)
   452  			resolver.Provide(lang, "service", name+"ClientPowerApi", from)
   453  			// }
   454  		}
   455  	}
   456  }
   457  
   458  // scalaLibraryOptions represents the parsed flag configuration for a scalaLibrary
   459  type scalaLibraryOptions struct {
   460  	noResolve        map[string]bool
   461  	exclude, include []string
   462  	plugins          []string
   463  	resolveWKTs      bool
   464  }
   465  
   466  func parseScalaLibraryOptions(kindName string, args []string) *scalaLibraryOptions {
   467  	flags := flag.NewFlagSet(kindName, flag.ExitOnError)
   468  
   469  	var noresolveFlagValue string
   470  	flags.StringVar(&noresolveFlagValue, "noresolve", "", "--noresolve=<path>.proto suppresses deps resolution of <path>.proto")
   471  
   472  	var resolveWKTs bool
   473  	flags.BoolVar(&resolveWKTs, "resolve_well_known_types", false, "--resolve_well_known_types=true enables resolution of well-known-types")
   474  
   475  	var excludeFlagValue string
   476  	flags.StringVar(&excludeFlagValue, "exclude", "", "--exclude=<file>.srcjar suppresses rule output for <glob>.srcjar.  If after removing all matching files, no outputs remain, the rule will not be emitted.")
   477  
   478  	var includeFlagValue string
   479  	flags.StringVar(&includeFlagValue, "include", "", "--include=<file>.srcjar keeps only rule output for <glob>.srcjar.  If after removing all matching files, no outputs remain, the rule will not be emitted.")
   480  
   481  	var pluginsFlagValue string
   482  	flags.StringVar(&pluginsFlagValue, "plugins", "", "--plugins=name1,name2 includes only those files generated by the given plugin names")
   483  
   484  	if err := flags.Parse(args); err != nil {
   485  		log.Fatalf("failed to parse flags for %q: %v", kindName, err)
   486  	}
   487  
   488  	config := &scalaLibraryOptions{
   489  		noResolve:   make(map[string]bool),
   490  		resolveWKTs: resolveWKTs,
   491  	}
   492  
   493  	for _, value := range strings.Split(noresolveFlagValue, ",") {
   494  		config.noResolve[value] = true
   495  	}
   496  	if len(excludeFlagValue) > 0 {
   497  		config.exclude = strings.Split(excludeFlagValue, ",")
   498  	}
   499  	if len(includeFlagValue) > 0 {
   500  		config.include = strings.Split(includeFlagValue, ",")
   501  	}
   502  	if len(pluginsFlagValue) > 0 {
   503  		config.plugins = strings.Split(pluginsFlagValue, ",")
   504  	}
   505  
   506  	return config
   507  }
   508  
   509  func (o *scalaLibraryOptions) filterOutputs(in []string) (out []string) {
   510  	if len(o.include) > 0 {
   511  		log.Printf("filtering includes %v %d %q", o.include, len(o.include), o.include[0])
   512  		files := make([]string, 0)
   513  
   514  		for _, value := range in {
   515  			var shouldInclude bool
   516  			for _, pattern := range o.include {
   517  				match, err := doublestar.PathMatch(pattern, value)
   518  				if err != nil {
   519  					log.Fatalf("bad --include pattern %q: %v", pattern, err)
   520  				}
   521  				if match {
   522  					shouldInclude = true
   523  					break
   524  				}
   525  			}
   526  			if shouldInclude {
   527  				files = append(files, value)
   528  			}
   529  		}
   530  
   531  		in = files
   532  	}
   533  
   534  next:
   535  	for _, value := range in {
   536  		for _, pattern := range o.exclude {
   537  			match, err := doublestar.PathMatch(pattern, value)
   538  			if err != nil {
   539  				log.Fatalf("bad --exclude pattern %q: %v", pattern, err)
   540  			}
   541  			if match {
   542  				continue next
   543  			}
   544  		}
   545  		out = append(out, value)
   546  	}
   547  
   548  	return
   549  }
   550  
   551  func (o *scalaLibraryOptions) filterImports(in []string) (out []string) {
   552  	for _, value := range in {
   553  		if o.noResolve[value] {
   554  			continue
   555  		}
   556  		out = append(out, value)
   557  	}
   558  	return
   559  }
   560  
   561  func messageFiles(in []*protoc.File) []*protoc.File {
   562  	return filterFiles(in, func(f *protoc.File) bool {
   563  		return !f.HasServices()
   564  	})
   565  }
   566  
   567  func serviceFiles(in []*protoc.File) []*protoc.File {
   568  	return filterFiles(in, func(f *protoc.File) bool {
   569  		return f.HasServices()
   570  	})
   571  }
   572  
   573  func filterFiles(in []*protoc.File, want func(f *protoc.File) bool) []*protoc.File {
   574  	out := make([]*protoc.File, 0, len(in))
   575  	for _, file := range in {
   576  		if want(file) {
   577  			out = append(out, file)
   578  		}
   579  	}
   580  	return out
   581  }
   582  
   583  func getPluginConfiguration(plugins []*protoc.PluginConfiguration, name string) *protoc.PluginConfiguration {
   584  	for _, plugin := range plugins {
   585  		if plugin.Config.Name == name {
   586  			return plugin
   587  		}
   588  	}
   589  	return nil
   590  }