github.com/stackb/rules_proto@v0.0.0-20240221195024-5428336c51f1/pkg/rule/rules_scala/scala_library.go (about) 1 package rules_scala 2 3 import ( 4 "flag" 5 "fmt" 6 "log" 7 "strings" 8 9 "github.com/bazelbuild/bazel-gazelle/config" 10 "github.com/bazelbuild/bazel-gazelle/label" 11 "github.com/bazelbuild/bazel-gazelle/resolve" 12 "github.com/bazelbuild/bazel-gazelle/rule" 13 "github.com/bmatcuk/doublestar" 14 "github.com/emicklei/proto" 15 16 "github.com/stackb/rules_proto/pkg/plugin/akka/akka_grpc" 17 "github.com/stackb/rules_proto/pkg/plugin/scalapb/scalapb" 18 "github.com/stackb/rules_proto/pkg/protoc" 19 ) 20 21 const ( 22 GrpcscalaLibraryRuleName = "grpc_scala_library" 23 ProtoscalaLibraryRuleName = "proto_scala_library" 24 protoScalaLibraryRuleSuffix = "_proto_scala_library" 25 grpcScalaLibraryRuleSuffix = "_grpc_scala_library" 26 scalaPbPluginOptionsPrivateKey = "_scalapb_plugin" 27 akkaGrpcPluginOptionsPrivateKey = "_akka_grpc_plugin" 28 scalapbOptionsName = "(scalapb.options)" 29 scalapbFieldTypeName = "(scalapb.field).type" 30 scalaLangName = "scala" 31 ) 32 33 func init() { 34 protoc.Rules().MustRegisterRule("stackb:rules_proto:"+ProtoscalaLibraryRuleName, 35 &scalaLibrary{ 36 kindName: ProtoscalaLibraryRuleName, 37 ruleSuffix: protoScalaLibraryRuleSuffix, 38 protoFileFilter: messageFiles, 39 }) 40 protoc.Rules().MustRegisterRule("stackb:rules_proto:"+GrpcscalaLibraryRuleName, 41 &scalaLibrary{ 42 kindName: GrpcscalaLibraryRuleName, 43 ruleSuffix: grpcScalaLibraryRuleSuffix, 44 protoFileFilter: serviceFiles, 45 }) 46 } 47 48 // scalaLibrary implements LanguageRule for the 'proto_scala_library' rule from 49 // @rules_proto. 50 type scalaLibrary struct { 51 kindName string 52 ruleSuffix string 53 protoFileFilter func([]*protoc.File) []*protoc.File 54 } 55 56 // Name implements part of the LanguageRule interface. 57 func (s *scalaLibrary) Name() string { 58 return s.kindName 59 } 60 61 // KindInfo implements part of the LanguageRule interface. 62 func (s *scalaLibrary) KindInfo() rule.KindInfo { 63 return rule.KindInfo{ 64 MergeableAttrs: map[string]bool{ 65 "srcs": true, 66 "exports": true, 67 }, 68 NonEmptyAttrs: map[string]bool{ 69 "srcs": true, 70 }, 71 ResolveAttrs: map[string]bool{ 72 "deps": true, 73 }, 74 } 75 } 76 77 // LoadInfo implements part of the LanguageRule interface. 78 func (s *scalaLibrary) LoadInfo() rule.LoadInfo { 79 return rule.LoadInfo{ 80 Name: fmt.Sprintf("@build_stack_rules_proto//rules/scala:%s.bzl", s.kindName), 81 Symbols: []string{s.kindName}, 82 } 83 } 84 85 // ProvideRule implements part of the LanguageRule interface. 86 func (s *scalaLibrary) ProvideRule(cfg *protoc.LanguageRuleConfig, pc *protoc.ProtocConfiguration) protoc.RuleProvider { 87 files := s.protoFileFilter(pc.Library.Files()) 88 if len(files) == 0 { 89 return nil 90 } 91 92 options := parseScalaLibraryOptions(s.kindName, cfg.GetOptions()) 93 94 // the list of output files 95 outputs := make([]string, 0) 96 97 if len(options.plugins) == 0 { 98 log.Printf("warning: the rule %s should have at least one plugin name for the --plugins option. This informs the rule which plugin(s) outputs correspond to this library rule", s.Name()) 99 } 100 101 for _, name := range options.plugins { 102 plugin := getPluginConfiguration(pc.Plugins, name) 103 if plugin == nil { 104 // TODO: warn here? 105 continue 106 } 107 outputs = append(outputs, plugin.Outputs...) 108 } 109 110 outputs = options.filterOutputs(outputs) 111 112 if len(outputs) == 0 { 113 return nil 114 } 115 116 return &scalaLibraryRule{ 117 kindName: s.kindName, 118 ruleNameSuffix: s.ruleSuffix, 119 options: options, 120 outputs: outputs, 121 ruleConfig: cfg, 122 config: pc, 123 files: files, 124 } 125 } 126 127 // scalaLibraryRule implements RuleProvider for 'scala_library'-derived rules. 128 type scalaLibraryRule struct { 129 kindName string 130 ruleNameSuffix string 131 outputs []string 132 config *protoc.ProtocConfiguration 133 ruleConfig *protoc.LanguageRuleConfig 134 options *scalaLibraryOptions 135 files []*protoc.File 136 } 137 138 // Kind implements part of the ruleProvider interface. 139 func (s *scalaLibraryRule) Kind() string { 140 return s.kindName 141 } 142 143 // Name implements part of the ruleProvider interface. 144 func (s *scalaLibraryRule) Name() string { 145 return s.config.Library.BaseName() + s.ruleNameSuffix 146 } 147 148 // Srcs computes the srcs list for the rule. 149 func (s *scalaLibraryRule) Srcs() []string { 150 srcs := make([]string, 0) 151 for _, output := range s.outputs { 152 if strings.HasSuffix(output, ".srcjar") { 153 srcs = append(srcs, protoc.StripRel(s.config.Rel, output)) 154 } 155 } 156 return srcs 157 } 158 159 // Deps computes the deps list for the rule. 160 func (s *scalaLibraryRule) Deps() []string { 161 deps := s.ruleConfig.GetDeps() 162 163 for _, pluginConfig := range s.config.Plugins { 164 deps = append(deps, pluginConfig.Config.GetDeps()...) 165 } 166 167 return protoc.DeduplicateAndSort(deps) 168 } 169 170 // Visibility provides visibility labels. 171 func (s *scalaLibraryRule) Visibility() []string { 172 return s.ruleConfig.GetVisibility() 173 } 174 175 // Rule implements part of the ruleProvider interface. 176 func (s *scalaLibraryRule) Rule(otherGen ...*rule.Rule) *rule.Rule { 177 newRule := rule.NewRule(s.Kind(), s.Name()) 178 179 newRule.SetAttr("srcs", s.Srcs()) 180 181 deps := s.Deps() 182 if len(deps) > 0 { 183 newRule.SetAttr("deps", deps) 184 } 185 186 exports := s.ruleConfig.GetAttr("exports") 187 if len(exports) > 0 { 188 newRule.SetAttr("exports", exports) 189 } 190 191 visibility := s.Visibility() 192 if len(visibility) > 0 { 193 newRule.SetAttr("visibility", visibility) 194 } 195 196 // add any imports from proto options. Example: option (scalapb.options) = 197 // { 198 // import: "com.foo.Bar" 199 // }; 200 // 201 // NOTE: we pass *all* files from the proto_library. Although the 202 // fileFilter has reduced the set into grpc or non-grpc ones, in practice 203 // protoc-gen-scala only has the "grpc" option. When OFF, it will produce a 204 // srcjar with only messages. When that is ON, the compiler will produce a 205 // srcjar with both messages and services. There is no way to tell the 206 // compiler to generate ONLY services (and not messages). Therefore, we 207 // need all dependencies in order to compile the messages. 208 scalaImports := getScalapbImports(s.config.Library.Files()) 209 if len(scalaImports) > 0 { 210 newRule.SetPrivateAttr(config.GazelleImportsKey, scalaImports) 211 } 212 213 // set the override language such that deps of 'proto_scala_library' and 214 // 'grpc_scala_library' can resolve together (matches the value used by 215 // "Imports"). 216 newRule.SetPrivateAttr(protoc.ResolverImpLangPrivateKey, "scala") 217 218 // add the scalapb plugin options as a private attr so we can inspect them 219 // during the .Imports() phase. For example, akka 'server_power_apis' 220 // generates additional classes. 221 scalaPbPlugin := s.config.GetPluginConfiguration(scalapb.ScalaPBPluginName) 222 if scalaPbPlugin != nil { 223 newRule.SetPrivateAttr(scalaPbPluginOptionsPrivateKey, scalaPbPlugin.Options) 224 } 225 akkaGrpcPlugin := s.config.GetPluginConfiguration(akka_grpc.AkkaGrpcPluginName) 226 if akkaGrpcPlugin != nil { 227 newRule.SetPrivateAttr(akkaGrpcPluginOptionsPrivateKey, akkaGrpcPlugin.Options) 228 } 229 230 return newRule 231 } 232 233 // Imports implements part of the RuleProvider interface. 234 func (s *scalaLibraryRule) Imports(c *config.Config, r *rule.Rule, file *rule.File) []resolve.ImportSpec { 235 // 1. provide generated scala class names for message and services for 236 // 'scala scala' deps. This will allow a scala extension to resolve proto 237 // deps when they import scala proto class names. 238 pluginOptions := make(map[string]bool) 239 if scalaPbPluginOptions, ok := r.PrivateAttr(scalaPbPluginOptionsPrivateKey).([]string); ok { 240 for _, opt := range scalaPbPluginOptions { 241 pluginOptions[opt] = true 242 } 243 } 244 if akkaGrpcPluginOptions, ok := r.PrivateAttr(akkaGrpcPluginOptionsPrivateKey).([]string); ok { 245 for _, opt := range akkaGrpcPluginOptions { 246 pluginOptions[opt] = true 247 } 248 } 249 from := label.New("", file.Pkg, r.Name()) 250 251 provideScalaImports(s.files, protoc.GlobalResolver(), from, pluginOptions) 252 253 // 2. create import specs for 'protobuf scala'. This allows 254 // proto_scala_library and grpc_scala_library to resolve deps. 255 return protoc.ProtoFilesImportSpecsForKind("scala", s.files) 256 } 257 258 // Resolve implements part of the RuleProvider interface. 259 func (s *scalaLibraryRule) Resolve(c *config.Config, ix *resolve.RuleIndex, r *rule.Rule, imports []string, from label.Label) { 260 imports = s.options.filterImports(imports) 261 262 resolveFn := protoc.ResolveDepsAttr("deps", !s.options.resolveWKTs) 263 resolveFn(c, ix, r, imports, from) 264 265 if unresolvedDeps, ok := r.PrivateAttr(protoc.UnresolvedDepsPrivateKey).(map[string]error); ok { 266 if from.Repo == c.RepoName { 267 from.Repo = "" 268 } 269 resolveScalaDeps(resolve.FindRuleWithOverride, ix.FindRulesByImportWithConfig, c, r, unresolvedDeps, from) 270 271 for imp, err := range unresolvedDeps { 272 if err == nil { 273 continue 274 } 275 log.Printf("%[1]v (%[2]s): warning: failed to resolve %[3]q: %v", from, r.Kind(), imp, err) 276 } 277 } 278 } 279 280 // findRuleWithOverride is the same shape of resolve.FindRuleWithOverride. 281 type findRuleWithOverride func(c *config.Config, imp resolve.ImportSpec, lang string) (label.Label, bool) 282 283 // findRulesByImportWithConfig is the same shape of resolve.RuleIndex.FindRulesByImportWithConfig. 284 // For testability want to avoid the RuleIndex as it is fundamentally tied to the resolve.resolveConfig, 285 // which is private and not easily mocked. 286 type findRulesByImportWithConfig func(c *config.Config, imp resolve.ImportSpec, lang string) []resolve.FindResult 287 288 // resolveScalaDeps attempts to resolve labels for the given deps under the 289 // "scala" language. Only unresolved deps of type ErrNoLabel are considered. 290 // Typically these unresolved dependencies arise from (scalapb.options) imports. 291 func resolveScalaDeps( 292 findRuleWithOverride findRuleWithOverride, 293 findRulesByImportWithConfig findRulesByImportWithConfig, 294 c *config.Config, 295 r *rule.Rule, 296 unresolvedDeps map[string]error, 297 from label.Label, 298 ) { 299 300 resolvedDeps := make([]string, 0) 301 302 markResolved := func(imp string, to label.Label) { 303 delete(unresolvedDeps, imp) 304 if to == from { 305 return 306 } 307 resolvedDeps = append(resolvedDeps, to.String()) 308 } 309 310 for imp, err := range unresolvedDeps { 311 if err != protoc.ErrNoLabel { 312 continue 313 } 314 importSpec := resolve.ImportSpec{Lang: scalaLangName, Imp: imp} 315 if l, ok := findRuleWithOverride(c, importSpec, scalaLangName); ok { 316 markResolved(imp, l) 317 continue 318 } 319 result := findRulesByImportWithConfig(c, importSpec, scalaLangName) 320 if len(result) == 0 { 321 continue 322 } 323 if len(result) > 1 { 324 log.Println(from, "multiple rules matched for scala import %q: %v", imp, result) 325 continue 326 } 327 markResolved(imp, result[0].Label) 328 } 329 if len(resolvedDeps) > 0 { 330 r.SetAttr("deps", protoc.DeduplicateAndSort(append(r.AttrStrings("deps"), resolvedDeps...))) 331 } 332 } 333 334 func getScalapbImports(files []*protoc.File) []string { 335 imps := make([]string, 0) 336 337 for _, file := range files { 338 for _, option := range file.Options() { 339 if option.Name != scalapbOptionsName { 340 continue 341 } 342 for _, namedLiteral := range option.Constant.OrderedMap { 343 switch namedLiteral.Name { 344 case "import": 345 if namedLiteral.Source != "" { 346 imps = append(imps, parseScalaImportNamedLiteral(namedLiteral.Source)...) 347 } 348 } 349 } 350 } 351 for _, msg := range file.Messages() { 352 for _, child := range msg.Elements { 353 if field, ok := child.(*proto.NormalField); ok { 354 for _, option := range field.Options { 355 if option.Name != scalapbFieldTypeName { 356 continue 357 } 358 if option.Constant.Source != "" { 359 imps = append(imps, option.Constant.Source) 360 } 361 } 362 } 363 } 364 } 365 } 366 367 return protoc.DeduplicateAndSort(imps) 368 } 369 370 func parseScalaImportNamedLiteral(lit string) (imports []string) { 371 ob := strings.Index(lit, "{") 372 cb := strings.Index(lit, "}") 373 if ob == -1 || cb == -1 { 374 return []string{lit} 375 } 376 prefix := strings.TrimRight(lit[:ob], ".") 377 exprs := strings.Split(lit[ob+1:cb], ",") 378 for _, expr := range exprs { 379 expr = strings.TrimSpace(expr) 380 parts := strings.Split(expr, "=>") 381 if len(parts) == 2 { 382 source := strings.TrimSpace(parts[0]) 383 imports = append(imports, prefix+"."+source) 384 } else { 385 imports = append(imports, prefix+"."+expr) 386 387 } 388 } 389 return 390 } 391 392 // javaPackageOption is a utility function to seek for the java_package option. 393 func javaPackageOption(options []proto.Option) (string, bool) { 394 for _, opt := range options { 395 if opt.Name != "java_package" { 396 continue 397 } 398 return opt.Constant.Source, true 399 } 400 401 return "", false 402 } 403 404 func provideScalaImports(files []*protoc.File, resolver protoc.ImportResolver, from label.Label, options map[string]bool) { 405 lang := "scala" 406 407 for _, file := range files { 408 pkgName := file.Package().Name 409 if javaPackageName, ok := javaPackageOption(file.Options()); ok { 410 pkgName = javaPackageName 411 } 412 if pkgName != "" { 413 resolver.Provide(lang, "package", pkgName, from) 414 } 415 for _, e := range file.Enums() { 416 name := e.Name 417 if pkgName != "" { 418 name = pkgName + "." + name 419 } 420 resolver.Provide(lang, "enum", name, from) 421 for _, value := range e.Elements { 422 if field, ok := value.(*proto.EnumField); ok { 423 fieldName := name + "." + field.Name 424 resolver.Provide(lang, lang, fieldName, from) 425 } 426 } 427 } 428 for _, m := range file.Messages() { 429 name := m.Name 430 if pkgName != "" { 431 name = pkgName + "." + name 432 } 433 resolver.Provide(lang, "message", name, from) 434 resolver.Provide(lang, "message", name+"Proto", from) 435 } 436 for _, s := range file.Services() { 437 name := s.Name 438 if pkgName != "" { 439 name = pkgName + "." + name 440 } 441 resolver.Provide(lang, "service", name, from) 442 resolver.Provide(lang, "service", name+"Grpc", from) 443 resolver.Provide(lang, "service", name+"Proto", from) 444 resolver.Provide(lang, "service", name+"Client", from) 445 resolver.Provide(lang, "service", name+"Handler", from) 446 resolver.Provide(lang, "service", name+"Server", from) 447 // TOOD: if this is configured on the proto_plugin, we won't know 448 // about the plugin option. Advertise them anyway. 449 // if options["server_power_apis"] { 450 resolver.Provide(lang, "service", name+"PowerApi", from) 451 resolver.Provide(lang, "service", name+"PowerApiHandler", from) 452 resolver.Provide(lang, "service", name+"ClientPowerApi", from) 453 // } 454 } 455 } 456 } 457 458 // scalaLibraryOptions represents the parsed flag configuration for a scalaLibrary 459 type scalaLibraryOptions struct { 460 noResolve map[string]bool 461 exclude, include []string 462 plugins []string 463 resolveWKTs bool 464 } 465 466 func parseScalaLibraryOptions(kindName string, args []string) *scalaLibraryOptions { 467 flags := flag.NewFlagSet(kindName, flag.ExitOnError) 468 469 var noresolveFlagValue string 470 flags.StringVar(&noresolveFlagValue, "noresolve", "", "--noresolve=<path>.proto suppresses deps resolution of <path>.proto") 471 472 var resolveWKTs bool 473 flags.BoolVar(&resolveWKTs, "resolve_well_known_types", false, "--resolve_well_known_types=true enables resolution of well-known-types") 474 475 var excludeFlagValue string 476 flags.StringVar(&excludeFlagValue, "exclude", "", "--exclude=<file>.srcjar suppresses rule output for <glob>.srcjar. If after removing all matching files, no outputs remain, the rule will not be emitted.") 477 478 var includeFlagValue string 479 flags.StringVar(&includeFlagValue, "include", "", "--include=<file>.srcjar keeps only rule output for <glob>.srcjar. If after removing all matching files, no outputs remain, the rule will not be emitted.") 480 481 var pluginsFlagValue string 482 flags.StringVar(&pluginsFlagValue, "plugins", "", "--plugins=name1,name2 includes only those files generated by the given plugin names") 483 484 if err := flags.Parse(args); err != nil { 485 log.Fatalf("failed to parse flags for %q: %v", kindName, err) 486 } 487 488 config := &scalaLibraryOptions{ 489 noResolve: make(map[string]bool), 490 resolveWKTs: resolveWKTs, 491 } 492 493 for _, value := range strings.Split(noresolveFlagValue, ",") { 494 config.noResolve[value] = true 495 } 496 if len(excludeFlagValue) > 0 { 497 config.exclude = strings.Split(excludeFlagValue, ",") 498 } 499 if len(includeFlagValue) > 0 { 500 config.include = strings.Split(includeFlagValue, ",") 501 } 502 if len(pluginsFlagValue) > 0 { 503 config.plugins = strings.Split(pluginsFlagValue, ",") 504 } 505 506 return config 507 } 508 509 func (o *scalaLibraryOptions) filterOutputs(in []string) (out []string) { 510 if len(o.include) > 0 { 511 log.Printf("filtering includes %v %d %q", o.include, len(o.include), o.include[0]) 512 files := make([]string, 0) 513 514 for _, value := range in { 515 var shouldInclude bool 516 for _, pattern := range o.include { 517 match, err := doublestar.PathMatch(pattern, value) 518 if err != nil { 519 log.Fatalf("bad --include pattern %q: %v", pattern, err) 520 } 521 if match { 522 shouldInclude = true 523 break 524 } 525 } 526 if shouldInclude { 527 files = append(files, value) 528 } 529 } 530 531 in = files 532 } 533 534 next: 535 for _, value := range in { 536 for _, pattern := range o.exclude { 537 match, err := doublestar.PathMatch(pattern, value) 538 if err != nil { 539 log.Fatalf("bad --exclude pattern %q: %v", pattern, err) 540 } 541 if match { 542 continue next 543 } 544 } 545 out = append(out, value) 546 } 547 548 return 549 } 550 551 func (o *scalaLibraryOptions) filterImports(in []string) (out []string) { 552 for _, value := range in { 553 if o.noResolve[value] { 554 continue 555 } 556 out = append(out, value) 557 } 558 return 559 } 560 561 func messageFiles(in []*protoc.File) []*protoc.File { 562 return filterFiles(in, func(f *protoc.File) bool { 563 return !f.HasServices() 564 }) 565 } 566 567 func serviceFiles(in []*protoc.File) []*protoc.File { 568 return filterFiles(in, func(f *protoc.File) bool { 569 return f.HasServices() 570 }) 571 } 572 573 func filterFiles(in []*protoc.File, want func(f *protoc.File) bool) []*protoc.File { 574 out := make([]*protoc.File, 0, len(in)) 575 for _, file := range in { 576 if want(file) { 577 out = append(out, file) 578 } 579 } 580 return out 581 } 582 583 func getPluginConfiguration(plugins []*protoc.PluginConfiguration, name string) *protoc.PluginConfiguration { 584 for _, plugin := range plugins { 585 if plugin.Config.Name == name { 586 return plugin 587 } 588 } 589 return nil 590 }