github.com/stackb/rules_proto@v0.0.0-20240221195024-5428336c51f1/pkg/protoc/package.go (about)

     1  package protoc
     2  
     3  import (
     4  	"log"
     5  	"path"
     6  	"sort"
     7  
     8  	"github.com/bazelbuild/bazel-gazelle/config"
     9  	"github.com/bazelbuild/bazel-gazelle/label"
    10  	"github.com/bazelbuild/bazel-gazelle/rule"
    11  )
    12  
    13  // Package provides a set of proto_library derived rules for the package.
    14  type Package struct {
    15  	// relative path of build file
    16  	rel string
    17  	// the config for this package
    18  	cfg *PackageConfig
    19  	// list of proto_library targets in the package
    20  	libs []ProtoLibrary
    21  	// computed providers
    22  	gen, empty []RuleProvider
    23  	// ruleLibs records the ProtoLibrary a RuleProvider was built on.
    24  	ruleLibs map[RuleProvider]ProtoLibrary
    25  	// providers record the provider of a rule, by rule name.
    26  	providers map[string]RuleProvider
    27  }
    28  
    29  // NewPackage constructs a Package given a list of proto_library rules
    30  // in the package.
    31  func NewPackage(rel string, cfg *PackageConfig, libs ...ProtoLibrary) *Package {
    32  	s := &Package{
    33  		rel:       rel,
    34  		cfg:       cfg,
    35  		libs:      libs,
    36  		ruleLibs:  make(map[RuleProvider]ProtoLibrary),
    37  		providers: make(map[string]RuleProvider),
    38  	}
    39  	s.gen = s.generateRules(true)
    40  	s.empty = s.generateRules(false)
    41  	return s
    42  }
    43  
    44  // generateRules constructs a list of rules based on the configured set of
    45  // languages.
    46  func (s *Package) generateRules(enabled bool) []RuleProvider {
    47  	rules := make([]RuleProvider, 0)
    48  	langs := s.cfg.configuredLangs()
    49  
    50  	for _, lang := range langs {
    51  		if enabled != lang.Enabled {
    52  			continue
    53  		}
    54  		for _, lib := range s.libs {
    55  			rules = append(rules, s.libraryRules(lang, lib)...)
    56  		}
    57  	}
    58  	return rules
    59  }
    60  
    61  func (s *Package) libraryRules(p *LanguageConfig, lib ProtoLibrary) []RuleProvider {
    62  	// list of plugin configurations that apply to this proto_library
    63  	configs := make([]*PluginConfiguration, 0)
    64  
    65  	for name, want := range p.Plugins {
    66  		if !want {
    67  			continue
    68  		}
    69  		plugin, ok := s.cfg.plugins[name]
    70  		if !ok {
    71  			log.Fatalf("plugin not configured: %q", name)
    72  		}
    73  		if !plugin.Enabled {
    74  			continue
    75  		}
    76  
    77  		ctx := &PluginContext{
    78  			Rel:           s.rel,
    79  			ProtoLibrary:  lib,
    80  			PackageConfig: *s.cfg,
    81  			PluginConfig:  *plugin,
    82  		}
    83  
    84  		if plugin.Implementation == "" {
    85  			plugin.Implementation = plugin.Name
    86  		}
    87  		impl, err := globalRegistry.LookupPlugin(plugin.Implementation)
    88  
    89  		if err == ErrUnknownPlugin {
    90  			log.Fatalf(
    91  				"%s: plugin not registered: %q (available: %v) [%+v]",
    92  				s.rel,
    93  				plugin.Implementation,
    94  				globalRegistry.PluginNames(),
    95  				plugin,
    96  			)
    97  		}
    98  		ctx.Plugin = impl
    99  
   100  		// Delegate to the implementation for configuration
   101  		config := impl.Configure(ctx)
   102  		if config == nil {
   103  			continue
   104  		}
   105  		config.Plugin = impl
   106  		config.Config = plugin.clone()
   107  		config.Options = DeduplicateAndSort(config.Options)
   108  
   109  		// plugin.Label overrides the default value from the implementation
   110  		if plugin.Label.Name != "" {
   111  			config.Label = plugin.Label
   112  		}
   113  
   114  		configs = append(configs, config)
   115  	}
   116  
   117  	if len(configs) == 0 {
   118  		return nil
   119  	}
   120  
   121  	imports := make([]string, len(lib.Files()))
   122  	for i, file := range lib.Files() {
   123  		imports[i] = path.Join(file.Dir, file.Basename)
   124  	}
   125  
   126  	rules := make([]RuleProvider, 0)
   127  
   128  	pc := newProtocConfiguration(s.cfg, p, s.cfg.Config.WorkDir, s.rel, p.Name, lib, configs)
   129  	for _, name := range p.GetRulesByIntent(true) {
   130  		ruleConfig, ok := s.cfg.rules[name]
   131  		if !ok {
   132  			names := make([]string, 0)
   133  			for name := range s.cfg.rules {
   134  				names = append(names, name)
   135  			}
   136  			log.Fatalf("proto_rule %q is not configured (available: %v)", name, names)
   137  		}
   138  		if !ruleConfig.Enabled {
   139  			continue
   140  		}
   141  
   142  		impl, err := globalRegistry.LookupRule(ruleConfig.Implementation)
   143  		if err == ErrUnknownRule {
   144  			log.Fatalf(
   145  				"%s: rule not registered: %q (available: %v)",
   146  				s.rel,
   147  				ruleConfig.Implementation,
   148  				globalRegistry.RuleNames(),
   149  			)
   150  		}
   151  		ruleConfig.Impl = impl
   152  
   153  		rule := impl.ProvideRule(ruleConfig, pc)
   154  		if rule == nil {
   155  			continue
   156  		}
   157  
   158  		s.ruleLibs[rule] = lib
   159  
   160  		rules = append(rules, rule)
   161  	}
   162  
   163  	return rules
   164  }
   165  
   166  // RuleProvider returns the provider of a rule or nil if not known.
   167  func (s *Package) RuleProvider(r *rule.Rule) RuleProvider {
   168  	if provider, ok := s.providers[r.Name()]; ok {
   169  		return provider
   170  	}
   171  	return nil
   172  }
   173  
   174  // Rules provides the aggregated rule list for the package.
   175  func (s *Package) Rules() []*rule.Rule {
   176  	return s.getProvidedRules(s.gen, true)
   177  }
   178  
   179  // Empty names the rules that can be deleted.
   180  func (s *Package) Empty() []*rule.Rule {
   181  	// it's a bit sad that we construct the full rules only for their kind and
   182  	// name, but that's how it is right now.
   183  	rules := s.getProvidedRules(s.empty, false)
   184  
   185  	empty := make([]*rule.Rule, len(rules))
   186  	for i, r := range rules {
   187  		empty[i] = rule.NewRule(r.Kind(), r.Name())
   188  	}
   189  
   190  	return empty
   191  }
   192  
   193  func (s *Package) getProvidedRules(providers []RuleProvider, shouldResolve bool) []*rule.Rule {
   194  	rules := make([]*rule.Rule, 0)
   195  	ruleIndexes := make(map[label.Label]int)
   196  
   197  	for _, p := range providers {
   198  		r := p.Rule(rules...)
   199  		if r == nil {
   200  			continue
   201  		}
   202  
   203  		if shouldResolve {
   204  			lib := s.ruleLibs[p]
   205  			r.SetPrivateAttr(ProtoLibraryKey, lib)
   206  			// package up imports, append those that might already be created.
   207  			imports := lib.Imports()
   208  			if existingImports, ok := r.PrivateAttr(config.GazelleImportsKey).([]string); ok {
   209  				imports = append(imports, existingImports...)
   210  			}
   211  			r.SetPrivateAttr(config.GazelleImportsKey, imports)
   212  		}
   213  
   214  		// if this is a duplicate (e.g. the rule provider returned an "other"
   215  		// rule), update the slice position, otherwise extend the rules slice.
   216  		from := label.New("", s.rel, r.Name())
   217  		if index, ok := ruleIndexes[from]; ok {
   218  			rules[index] = r
   219  		} else {
   220  			// record the association of the rule provider here for the
   221  			// resolver.  Only the first occurrence of this rule name gets
   222  			// associated with the provider.  The `go_library.go` file relies on
   223  			// this behavior when merging rules.
   224  			s.providers[r.Name()] = p
   225  
   226  			ruleIndexes[from] = len(rules)
   227  			rules = append(rules, r)
   228  		}
   229  	}
   230  
   231  	if shouldResolve {
   232  		file := rule.EmptyFile("", s.rel)
   233  		for _, r := range rules {
   234  			provider := s.providers[r.Name()]
   235  			from := label.New("", s.rel, r.Name())
   236  			provideResolverImportSpecs(s.cfg.Config, provider, r, file, from)
   237  		}
   238  	}
   239  
   240  	return rules
   241  }
   242  
   243  func provideResolverImportSpecs(c *config.Config, provider RuleProvider, r *rule.Rule, f *rule.File, from label.Label) {
   244  	for _, imp := range provider.Imports(c, r, f) {
   245  		GlobalResolver().Provide(
   246  			"protobuf",
   247  			imp.Lang,
   248  			imp.Imp,
   249  			from,
   250  		)
   251  	}
   252  }
   253  
   254  // DeduplicateAndSort removes duplicate entries and sorts the list
   255  func DeduplicateAndSort(in []string) (out []string) {
   256  	if len(in) == 0 {
   257  		return in
   258  	}
   259  	seen := make(map[string]bool)
   260  	for _, v := range in {
   261  		if seen[v] {
   262  			continue
   263  		}
   264  		seen[v] = true
   265  		out = append(out, v)
   266  	}
   267  	sort.Strings(out)
   268  	return
   269  }