github.com/stackb/rules_proto@v0.0.0-20240221195024-5428336c51f1/pkg/protoc/starlark_rule.go (about)

     1  package protoc
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"os"
     7  
     8  	"github.com/bazelbuild/bazel-gazelle/config"
     9  	"github.com/bazelbuild/bazel-gazelle/label"
    10  	"github.com/bazelbuild/bazel-gazelle/resolve"
    11  	"github.com/bazelbuild/bazel-gazelle/rule"
    12  	"go.starlark.net/starlark"
    13  	"go.starlark.net/starlarkstruct"
    14  )
    15  
    16  func LoadStarlarkLanguageRuleFromFile(workDir, filename, name string, reporter func(msg string), errorReporter func(err error)) (LanguageRule, error) {
    17  	filename, err := resolveStarlarkFilename(workDir, filename)
    18  	if err != nil {
    19  		return nil, err
    20  	}
    21  
    22  	f, err := os.Open(filename)
    23  	if err != nil {
    24  		return nil, fmt.Errorf("failed to open rule file %q: %w", filename, err)
    25  	}
    26  	defer f.Close()
    27  
    28  	return loadStarlarkLanguageRule(name, filename, f, reporter, errorReporter)
    29  }
    30  
    31  func loadStarlarkLanguageRule(name, filename string, src interface{}, reporter func(msg string), errorReporter func(err error)) (LanguageRule, error) {
    32  	newErrorf := func(msg string, args ...interface{}) error {
    33  		err := fmt.Errorf(filename+": "+msg, args...)
    34  		errorReporter(err)
    35  		return err
    36  	}
    37  
    38  	plugins := make(map[string]*starlarkstruct.Struct)
    39  	rules := make(map[string]*starlarkstruct.Struct)
    40  	predeclared := newPredeclared(plugins, rules)
    41  
    42  	_, thread, err := loadStarlarkProgram(filename, src, predeclared, reporter, errorReporter)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	if rule, ok := rules[name]; !ok {
    48  		return nil, newErrorf("rule %q was never declared", name)
    49  	} else {
    50  		return &starlarkLanguageRule{
    51  			name:          name,
    52  			rule:          rule,
    53  			reporter:      thread.Print,
    54  			errorReporter: newErrorf,
    55  		}, nil
    56  	}
    57  }
    58  
    59  func newStarlarkLanguageRuleFunction(rules map[string]*starlarkstruct.Struct) goStarlarkFunction {
    60  	return func(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
    61  		var name string
    62  		var provideRule starlark.Callable
    63  		var loadInfo, kindInfo starlark.Callable
    64  
    65  		if err := starlark.UnpackArgs("Rule", args, kwargs,
    66  			"name", &name,
    67  			"load_info", &loadInfo,
    68  			"kind_info", &kindInfo,
    69  			"provide_rule", &provideRule,
    70  		); err != nil {
    71  			return nil, err
    72  		}
    73  
    74  		rule := starlarkstruct.FromStringDict(
    75  			Symbol("Rule"),
    76  			starlark.StringDict{
    77  				"name":         starlark.String(name),
    78  				"load_info":    loadInfo,
    79  				"kind_info":    kindInfo,
    80  				"provide_rule": provideRule,
    81  			},
    82  		)
    83  
    84  		rules[name] = rule
    85  		return rule, nil
    86  	}
    87  }
    88  
// starlarkLanguageRule is a rule implemented in starlark that implements the protoc
// LanguageRule interface.
type starlarkLanguageRule struct {
	// name is the name the rule was declared and looked up under.
	name string
	// reporter is installed as the Print function of every starlark thread
	// used to call into the rule's functions.
	reporter func(thread *starlark.Thread, msg string)
	// errorReporter formats an error, forwards it to the configured error
	// sink, and returns it.
	errorReporter func(msg string, args ...interface{}) error
	// rule is the struct built by the starlark Rule builtin. Its load_info,
	// kind_info and provide_rule attributes are the callables invoked by the
	// LanguageRule methods.
	rule *starlarkstruct.Struct
}
    97  
// Name returns the name under which the starlark rule was declared.
func (p *starlarkLanguageRule) Name() string {
	return p.name
}
   101  
   102  // LoadInfo returns the gazelle LoadInfo.
   103  func (p *starlarkLanguageRule) LoadInfo() (result rule.LoadInfo) {
   104  	callable, err := p.rule.Attr("load_info")
   105  	if err != nil {
   106  		log.Fatalf("LoadInfo() called on rule %q with no load_info function: %v", p.name, p.rule)
   107  		return result
   108  	}
   109  
   110  	thread := new(starlark.Thread)
   111  	thread.Print = p.reporter
   112  	value, err := starlark.Call(thread, callable, starlark.Tuple{}, []starlark.Tuple{})
   113  	if err != nil {
   114  		log.Fatalf("rule %q load_info failed: %v", p.name, err)
   115  		return result
   116  	}
   117  
   118  	switch value := value.(type) {
   119  	case *starlarkstruct.Struct:
   120  		result.Name = structAttrString(value, "name", p.errorReporter)
   121  		result.Symbols = structAttrStringSlice(value, "symbols", p.errorReporter)
   122  		result.After = structAttrStringSlice(value, "after", p.errorReporter)
   123  	default:
   124  		p.errorReporter("rule %q provide_rule returned invalid type: %T", p.name, value)
   125  	}
   126  	return
   127  }
   128  
   129  // KindInfo returns the gazelle KindInfo.
   130  func (p *starlarkLanguageRule) KindInfo() (result rule.KindInfo) {
   131  	callable, err := p.rule.Attr("kind_info")
   132  	if err != nil {
   133  		p.errorReporter("rule %q has no kind_info function", p.name)
   134  		return result
   135  	}
   136  
   137  	thread := new(starlark.Thread)
   138  	thread.Print = p.reporter
   139  	value, err := starlark.Call(thread, callable, starlark.Tuple{}, []starlark.Tuple{})
   140  	if err != nil {
   141  		p.errorReporter("rule %q kind_info failed: %w", p.name, err)
   142  		return result
   143  	}
   144  
   145  	switch value := value.(type) {
   146  	case *starlarkstruct.Struct:
   147  		result.MatchAny = structAttrBool(value, "match_any", p.errorReporter)
   148  		result.MatchAttrs = structAttrStringSlice(value, "match_attrs", p.errorReporter)
   149  		result.NonEmptyAttrs = structAttrMapStringBool(value, "non_empty_attrs", p.errorReporter)
   150  		result.SubstituteAttrs = structAttrMapStringBool(value, "substitute_attrs", p.errorReporter)
   151  		result.MergeableAttrs = structAttrMapStringBool(value, "mergeable_attrs", p.errorReporter)
   152  		result.ResolveAttrs = structAttrMapStringBool(value, "resolve_attrs", p.errorReporter)
   153  	default:
   154  		p.errorReporter("rule %q provide_rule returned invalid type: %T", p.name, value)
   155  	}
   156  
   157  	return
   158  }
   159  
   160  // ProvideRule takes the given configration and compilation and emits a
   161  // RuleProvider.  If the state of the ProtocConfiguration is such that the
   162  // rule should not be emitted, implementation should return nil.
   163  func (p *starlarkLanguageRule) ProvideRule(rc *LanguageRuleConfig, pc *ProtocConfiguration) RuleProvider {
   164  
   165  	callable, err := p.rule.Attr("provide_rule")
   166  	if err != nil {
   167  		p.errorReporter("rule %q has no provide_rule function", p.name)
   168  		return nil
   169  	}
   170  
   171  	thread := new(starlark.Thread)
   172  	thread.Print = p.reporter
   173  	value, err := starlark.Call(thread, callable, starlark.Tuple{
   174  		newLanguageRuleConfigStruct(rc),
   175  		newProtocConfigurationStruct(pc),
   176  	}, []starlark.Tuple{})
   177  	if err != nil {
   178  		p.errorReporter("rule %q provide_rule failed: %w", p.name, err)
   179  		return nil
   180  	}
   181  
   182  	var result RuleProvider
   183  	switch value := value.(type) {
   184  	case *starlarkstruct.Struct:
   185  		var experimentalResolveDepsAttr string
   186  		if attr, err := value.Attr("experimental_resolve_attr"); err == nil {
   187  			if str, ok := attr.(starlark.String); ok {
   188  				experimentalResolveDepsAttr = str.GoString()
   189  			}
   190  		}
   191  		result = &starlarkRuleProvider{
   192  			name:                        p.name,
   193  			provider:                    value,
   194  			reporter:                    p.reporter,
   195  			errorReporter:               p.errorReporter,
   196  			experimentalResolveDepsAttr: experimentalResolveDepsAttr,
   197  		}
   198  	default:
   199  		p.errorReporter("rule %q provide_rule returned invalid type: %T", p.name, value)
   200  		return nil
   201  	}
   202  
   203  	return result
   204  }
   205  
// starlarkRuleProvider implements RuleProvider via a starlark struct.
type starlarkRuleProvider struct {
	// name is the name of the starlarkLanguageRule that produced this provider.
	name string
	// provider is the struct returned by the starlark provide_rule function.
	// Its kind, name and rule attributes are read on demand.
	provider *starlarkstruct.Struct
	// reporter is installed as the Print function of starlark threads and is
	// also called directly for progress messages.
	reporter func(thread *starlark.Thread, msg string)
	// errorReporter formats, reports and returns an error.
	errorReporter func(msg string, args ...interface{}) error
	// experimentalResolveDepsAttr is the optional experimental_resolve_attr
	// value from the provider struct. Empty disables Resolve and Imports.
	experimentalResolveDepsAttr string
}
   214  
   215  // Kind implements part of the RuleProvider interface.
   216  func (s *starlarkRuleProvider) Kind() string {
   217  	kind, err := s.provider.Attr("kind")
   218  	if err != nil {
   219  		s.errorReporter("provider %q has no kind", s.Name())
   220  		return ""
   221  	}
   222  	return kind.(starlark.String).GoString()
   223  }
   224  
// Name implements part of the RuleProvider interface.
//
// It returns the provider struct's "name" attribute, read via
// structAttrString with the provider's error reporter.
func (s *starlarkRuleProvider) Name() string {
	return structAttrString(s.provider, "name", s.errorReporter)
}
   229  
   230  // Rule implements part of the RuleProvider interface.
   231  func (s *starlarkRuleProvider) Rule(othergen ...*rule.Rule) *rule.Rule {
   232  	callable, err := s.provider.Attr("rule")
   233  	if err != nil {
   234  		s.errorReporter("rule %q has no rule() function", s.name)
   235  		return nil
   236  	}
   237  
   238  	thread := new(starlark.Thread)
   239  	s.reporter(thread, "Invoking rule "+s.name)
   240  	thread.Print = s.reporter
   241  	value, err := starlark.Call(thread, callable, starlark.Tuple{}, []starlark.Tuple{})
   242  	if err != nil {
   243  		s.errorReporter("provider %q rule() failed: %w", s.name, err)
   244  		return nil
   245  	}
   246  
   247  	switch value := value.(type) {
   248  	case starlark.NoneType:
   249  		return nil
   250  	case *starlarkstruct.Struct:
   251  		rKind := structAttrString(value, "kind", s.errorReporter)
   252  		if rKind == "" {
   253  			s.errorReporter("rule %q has no kind", s.name)
   254  			return nil
   255  		}
   256  		rName := structAttrString(value, "name", s.errorReporter)
   257  		if rName == "" {
   258  			s.errorReporter("rule %q has no name", s.name)
   259  			return nil
   260  		}
   261  		r := rule.NewRule(rKind, rName)
   262  		s.reporter(thread, "Created rule "+rKind+" "+rName)
   263  
   264  		attrs, err := value.Attr("attrs")
   265  		if err != nil {
   266  			s.errorReporter("provider %q rule() returned invalid type: %T", s.name, value)
   267  			return nil
   268  		}
   269  		switch attrs := attrs.(type) {
   270  		case *starlark.Dict:
   271  			for _, attr := range attrs.Keys() {
   272  				attrName, ok := attr.(starlark.String)
   273  				if !ok {
   274  					s.errorReporter("%q rule attr key is invalid type (want string, got %T)", s.name, attr)
   275  					continue
   276  				}
   277  				if attrValue, ok, err := attrs.Get(attrName); ok && err == nil {
   278  					switch t := attrValue.(type) {
   279  					case *starlark.Bool:
   280  						r.SetAttr(attrName.GoString(), bool(*t))
   281  					case *starlark.Int:
   282  						intValue, _ := t.Int64()
   283  						r.SetAttr(attrName.GoString(), intValue)
   284  					case starlark.String:
   285  						r.SetAttr(attrName.GoString(), t.GoString())
   286  					case *starlark.List:
   287  						s.reporter(thread, fmt.Sprintf("!!! %q rule attr %q is a list", s.name, attrName.GoString()))
   288  						r.SetAttr(attrName.GoString(), stringSlice(t, s.errorReporter))
   289  					default:
   290  						s.errorReporter("%q rule attr value is invalid type (want bool, int, string, list, got %T)", s.name, t)
   291  					}
   292  				}
   293  			}
   294  		default:
   295  			s.errorReporter("%q rule.attrs returned invalid type: %T", s.name, value)
   296  		}
   297  		return r
   298  	default:
   299  		s.errorReporter("provider %q rule() returned invalid type: %T", s.name, value)
   300  		return nil
   301  	}
   302  }
   303  
   304  // Resolve implements part of the RuleProvider interface.
   305  func (s *starlarkRuleProvider) Resolve(c *config.Config, ix *resolve.RuleIndex, r *rule.Rule, imports []string, from label.Label) {
   306  	if s.experimentalResolveDepsAttr != "" {
   307  		if r.Attr(s.experimentalResolveDepsAttr) != nil {
   308  			ResolveDepsAttr(s.experimentalResolveDepsAttr, false)(c, ix, r, imports, from)
   309  		}
   310  	}
   311  }
   312  
   313  // Imports implements part of the RuleProvider interface.
   314  func (s *starlarkRuleProvider) Imports(c *config.Config, r *rule.Rule, file *rule.File) []resolve.ImportSpec {
   315  	if s.experimentalResolveDepsAttr != "" {
   316  		if lib, ok := r.PrivateAttr(ProtoLibraryKey).(ProtoLibrary); ok {
   317  			return ProtoLibraryImportSpecsForKind(r.Kind(), lib)
   318  		}
   319  	}
   320  	return nil
   321  }
   322  
   323  func newLanguageRuleConfigStruct(rc *LanguageRuleConfig) *starlarkstruct.Struct {
   324  	if rc == nil {
   325  		return starlarkstruct.FromStringDict(
   326  			Symbol("LanguageRuleConfig"),
   327  			starlark.StringDict{
   328  				"attrs":          &starlark.Dict{},
   329  				"config":         newConfigStruct(&config.Config{}),
   330  				"deps":           &starlark.List{},
   331  				"enabled":        starlark.Bool(false),
   332  				"implementation": starlark.String(""),
   333  				"name":           starlark.String(""),
   334  				"options":        &starlark.List{},
   335  				"visibility":     &starlark.List{},
   336  			},
   337  		)
   338  	}
   339  	return starlarkstruct.FromStringDict(
   340  		Symbol("LanguageRuleConfig"),
   341  		starlark.StringDict{
   342  			"attrs":          newStringListDict(rc.Attrs),
   343  			"config":         newConfigStruct(rc.Config),
   344  			"deps":           newStringList(rc.GetDeps()),
   345  			"enabled":        starlark.Bool(rc.Enabled),
   346  			"implementation": starlark.String(rc.Implementation),
   347  			"name":           starlark.String(rc.Name),
   348  			"options":        newStringList(rc.GetOptions()),
   349  			"visibility":     newStringList(rc.GetVisibility()),
   350  		},
   351  	)
   352  }
   353  
   354  func newProtocConfigurationStruct(pc *ProtocConfiguration) *starlarkstruct.Struct {
   355  	if pc == nil {
   356  		return starlarkstruct.FromStringDict(
   357  			Symbol("ProtocConfiguration"),
   358  			starlark.StringDict{
   359  				"package_config":  newPackageConfigStruct(nil),
   360  				"language_config": newLanguageConfigStruct(nil),
   361  				"library":         starlark.None,
   362  				"rel":             starlark.String(""),
   363  				"prefix":          starlark.String(""),
   364  				"outputs":         newStringList([]string{}),
   365  				"imports":         newStringList([]string{}),
   366  				"mappings":        &starlark.Dict{},
   367  				"plugins":         &starlark.List{},
   368  			},
   369  		)
   370  	}
   371  	return starlarkstruct.FromStringDict(
   372  		Symbol("ProtocConfiguration"),
   373  		starlark.StringDict{
   374  			"package_config":  newPackageConfigStruct(pc.PackageConfig),
   375  			"language_config": newLanguageConfigStruct(pc.LanguageConfig),
   376  			"proto_library":   newProtoLibraryStruct(pc.Library),
   377  			"rel":             starlark.String(pc.Rel),
   378  			"prefix":          starlark.String(pc.Prefix),
   379  			"outputs":         newStringList(pc.Outputs),
   380  			"imports":         newStringList(pc.Imports),
   381  			"mappings":        newStringStringDict(pc.Mappings),
   382  			"plugins":         newPluginConfigurationList(pc.Plugins),
   383  		},
   384  	)
   385  }
   386  
   387  func newLanguageConfigStruct(lc *LanguageConfig) *starlarkstruct.Struct {
   388  	if lc == nil {
   389  		return starlarkstruct.FromStringDict(
   390  			Symbol("LanguageConfig"),
   391  			starlark.StringDict{
   392  				"name":    starlark.String(""),
   393  				"protoc":  starlark.String(""),
   394  				"enabled": starlark.Bool(false),
   395  				"plugins": &starlark.Dict{},
   396  				"rules":   &starlark.Dict{},
   397  			},
   398  		)
   399  	}
   400  	return starlarkstruct.FromStringDict(
   401  		Symbol("LanguageConfig"),
   402  		starlark.StringDict{
   403  			"name":    starlark.String(lc.Name),
   404  			"protoc":  starlark.String(lc.Protoc),
   405  			"enabled": starlark.Bool(lc.Enabled),
   406  			"plugins": newStringBoolDict(lc.Plugins),
   407  			"rules":   newStringBoolDict(lc.Rules),
   408  		},
   409  	)
   410  }
   411  
   412  func newPluginConfigurationList(plugins []*PluginConfiguration) *starlark.List {
   413  	list := make([]starlark.Value, len(plugins))
   414  	for i, p := range plugins {
   415  		list[i] = newPluginConfigurationStruct(*p)
   416  	}
   417  	return starlark.NewList(list)
   418  }
   419  
   420  func newPluginConfigurationStruct(p PluginConfiguration) *starlarkstruct.Struct {
   421  	return starlarkstruct.FromStringDict(
   422  		Symbol("PluginConfiguration"),
   423  		starlark.StringDict{
   424  			"config":   newLanguagePluginConfigStruct(*p.Config),
   425  			"label":    starlark.String(p.Label.String()),
   426  			"out":      starlark.String(p.Out),
   427  			"options":  newStringList(p.Options),
   428  			"outputs":  newStringList(p.Outputs),
   429  			"mappings": newStringStringDict(p.Mappings),
   430  		},
   431  	)
   432  }