github.com/stackb/rules_proto@v0.0.0-20240221195024-5428336c51f1/pkg/protoc/resolver.go (about)

     1  package protoc
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"io"
     7  	"log"
     8  	"os"
     9  	"reflect"
    10  	"sort"
    11  	"strings"
    12  	"unsafe"
    13  
    14  	"github.com/bazelbuild/bazel-gazelle/config"
    15  	"github.com/bazelbuild/bazel-gazelle/label"
    16  	"github.com/bazelbuild/bazel-gazelle/resolve"
    17  )
    18  
    19  // globalImportResolver is the default resolver singleton.
    20  var globalImportResolver = NewImportResolver(&ImportResolverOptions{
    21  	Debug:  false,
    22  	Printf: log.Printf,
    23  }).(*resolver)
    24  
    25  const (
    26  	// ResolveProvidesKey is the key expected to store a string slice that
    27  	// informs what imports a rule provides.
    28  	ResolveProvidesKey = "_resolve_provides"
    29  )
    30  
    31  type ImportResolver interface {
    32  	// Resolve returns any previously provided labels associated with the given
    33  	// kind and import.
    34  	Resolve(lang, impLang, imp string) []resolve.FindResult
    35  	// Provide records the association between a rule kind+attr, the value of
    36  	// the attr, and the label that provides the value.
    37  	Provide(lang string, impLang, imp string, location label.Label)
    38  	// Imports takes a callback and iterates all known imports for the given
    39  	// lang/impLang. True should be returned to continue iteration.
    40  	Imports(lang, impLang string, visitor func(imp string, location []label.Label) bool)
    41  }
    42  
    43  // ImportProvider is an entity that can be queried for imported symbols by
    44  // language name.
    45  type ImportProvider interface {
    46  	// Provided returns all known provided symbols, by label
    47  	Provided(lang, impLang string) map[label.Label][]string
    48  }
    49  
    50  // ImportCrossResolver handles dependency resolution.
    51  type ImportCrossResolver interface {
    52  	resolve.CrossResolver
    53  	ImportResolver
    54  	ImportProvider
    55  
    56  	// LoadFile loads csv file to populate the resolver
    57  	LoadFile(filename string) error
    58  	// SaveFile writes a csv file
    59  	SaveFile(filename, repoName string) error
    60  	// Install adds configured resolve entries into the resolve config.
    61  	Install(c *config.Config)
    62  }
    63  
    64  // GlobalResolver returns a reference to the global ImportResolver
    65  func GlobalResolver() ImportCrossResolver {
    66  	return globalImportResolver
    67  }
    68  
// ImportResolverOptions configures a resolver created by NewImportResolver.
type ImportResolverOptions struct {
	// Printf receives debug output. Within this file it is only called when
	// Debug is true.
	Printf func(format string, args ...interface{})
	// Debug enables logging of provide/resolve activity through Printf.
	Debug bool
}
    73  
    74  func NewImportResolver(options *ImportResolverOptions) ImportResolver {
    75  	return &resolver{
    76  		known:   make(map[string]importLabels),
    77  		options: options,
    78  	}
    79  }
    80  
// importLabels records which labels are associated with a given proto import
// statement. It maps an import string to the labels that provide it; Provide
// appends each distinct label once, so the slice is in first-provided order.
type importLabels map[string][]label.Label
    84  
// resolver implements ImportResolver. The package-wide instance is also
// returned as an ImportCrossResolver by GlobalResolver.
//
// NOTE(review): known is not guarded by a mutex, so concurrent
// Provide/Resolve calls would race. Confirm callers are single-threaded.
type resolver struct {
	// options controls debug logging. Several methods dereference it without
	// a nil check.
	options *ImportResolverOptions
	// known maps a langKey(lang, impLang) string to the imports recorded for
	// that pair and the labels providing each of them.
	known map[string]importLabels
}
    91  
    92  // LoadFile reads a protoresolve csv file.
    93  func (r *resolver) LoadFile(filename string) error {
    94  	f, err := os.Open(filename)
    95  	if err != nil {
    96  		return err
    97  	}
    98  	defer f.Close()
    99  	return r.Load(f)
   100  }
   101  
   102  // Load reads input and returns a list of items.  Comment lines beginning
   103  // with '#' are ignored.
   104  func (r *resolver) Load(in io.Reader) error {
   105  	scanner := bufio.NewScanner(in)
   106  	for scanner.Scan() {
   107  		line := scanner.Text()
   108  		if strings.HasPrefix(line, "#") {
   109  			continue
   110  		}
   111  		parts := strings.SplitN(line, ",", 4)
   112  		if len(parts) != 4 {
   113  			log.Printf("warn: parse %q, expected 4 items, got %d", line, len(parts))
   114  			continue
   115  		}
   116  		lang := parts[0]
   117  		impLang := parts[1]
   118  		imp := parts[2]
   119  		lbl, err := label.Parse(parts[3])
   120  		if err != nil {
   121  			return fmt.Errorf("malformed label at position 4 in %s: %v", line, err)
   122  		}
   123  		r.Provide(lang, impLang, imp, lbl)
   124  	}
   125  	return nil
   126  }
   127  
   128  func (r *resolver) Save(out io.Writer, repoName string) {
   129  	keys := make([]string, 0)
   130  	for k := range r.known {
   131  		keys = append(keys, k)
   132  	}
   133  	sort.Strings(keys)
   134  	for _, key := range keys {
   135  		imports := r.known[key]
   136  		imps := make([]string, 0)
   137  		for imp := range imports {
   138  			imps = append(imps, imp)
   139  		}
   140  		sort.Strings(imps)
   141  		lang, impLang := keyLang(key)
   142  		for _, imp := range imps {
   143  			labels := imports[imp]
   144  			for _, lbl := range labels {
   145  				// skip external labels, these represent externally loaded
   146  				// entries and we don't write transitive resolves
   147  				if lbl.Repo != "" {
   148  					continue
   149  				}
   150  				l := label.New(repoName, lbl.Pkg, lbl.Name)
   151  				fmt.Fprintf(out, "%s,%s,%s,%s\n", lang, impLang, imp, l)
   152  			}
   153  		}
   154  	}
   155  }
   156  
   157  func (r *resolver) SaveFile(filename, repoName string) error {
   158  	f, err := os.Create(filename)
   159  	if err != nil {
   160  		return fmt.Errorf("save imports file: %w", err)
   161  	}
   162  
   163  	fmt.Fprintf(f, "# GENERATED FILE, DO NOT EDIT (created by gazelle)\n")
   164  	fmt.Fprintf(f, "# lang,imp.lang,imp,label\n")
   165  
   166  	r.Save(f, repoName)
   167  	if err := f.Close(); err != nil {
   168  		return err
   169  	}
   170  
   171  	// log.Println("Wrote resolve file:", filename)
   172  	return nil
   173  }
   174  
   175  // CrossResolve provides dependency resolution logic for the protobuf language extension.
   176  func (r *resolver) CrossResolve(c *config.Config, ix *resolve.RuleIndex, imp resolve.ImportSpec, lang string) []resolve.FindResult {
   177  	res := r.Resolve(lang, imp.Lang, imp.Imp)
   178  	if r.options.Debug {
   179  		r.options.Printf("cross-resolve %s %s %s (%d results)", lang, imp.Lang, imp.Imp, len(res))
   180  	}
   181  	return res
   182  }
   183  
   184  // Resolve implements part of the ImportResolver interface.
   185  func (r *resolver) Resolve(lang, impLang, imp string) []resolve.FindResult {
   186  	key := langKey(lang, impLang)
   187  	known := r.known[key]
   188  	if known == nil {
   189  		known = r.known[lang]
   190  	}
   191  	if known == nil {
   192  		if r.options.Debug {
   193  			r.options.Printf("resolve miss %s: no records under language %q", imp, key)
   194  		}
   195  		return nil
   196  	}
   197  	if got, ok := known[imp]; ok {
   198  		res := make([]resolve.FindResult, len(got))
   199  		for i, l := range got {
   200  			res[i] = resolve.FindResult{Label: l}
   201  		}
   202  		// reverse results to preserve last-wins semantics of prior
   203  		// stackb/rules_proto behavior
   204  		for i, j := 0, len(res)-1; i < j; i, j = i+1, j-1 {
   205  			res[i], res[j] = res[j], res[i]
   206  		}
   207  		return res
   208  	}
   209  	return nil
   210  }
   211  
   212  // Provide implements part of the ImportResolver interface.
   213  func (r *resolver) Provide(lang, impLang, imp string, from label.Label) {
   214  	key := langKey(lang, impLang)
   215  	known, ok := r.known[key]
   216  	if !ok {
   217  		known = make(map[string][]label.Label)
   218  		r.known[key] = known
   219  	}
   220  	for _, v := range known[imp] {
   221  		if v == from {
   222  			return
   223  		}
   224  	}
   225  	if r.options.Debug {
   226  		r.options.Printf("resolver %v provides %s %s", from, key, imp)
   227  	}
   228  	known[imp] = append(known[imp], from)
   229  }
   230  
   231  // Provided implements the ImportProvider interface.
   232  func (r *resolver) Provided(lang, impLang string) map[label.Label][]string {
   233  	if len(r.known) == 0 {
   234  		return nil
   235  	}
   236  	result := make(map[label.Label][]string)
   237  	key := langKey(lang, impLang)
   238  	known := r.known[key]
   239  	if known == nil {
   240  		known = r.known[lang]
   241  	}
   242  	if known == nil {
   243  		return nil
   244  	}
   245  	for imp, ll := range known {
   246  		for _, l := range ll {
   247  			result[l] = append(result[l], imp)
   248  		}
   249  	}
   250  	return result
   251  }
   252  
   253  // Imports implements part of the ImportResolver interface.
   254  func (r *resolver) Imports(lang, impLang string, visitor func(imp string, location []label.Label) bool) {
   255  	key := langKey(lang, impLang)
   256  	known := r.known[key]
   257  	if known == nil {
   258  		known = r.known[lang]
   259  	}
   260  	if known == nil {
   261  		return
   262  	}
   263  	for k, v := range known {
   264  		if !visitor(k, v) {
   265  			break
   266  		}
   267  	}
   268  }
   269  
   270  func (r *resolver) Install(c *config.Config) {
   271  	overrides := make(overrideSpec, 0)
   272  
   273  	for key, known := range r.known {
   274  		lang, impLang := keyLang(key)
   275  		for imp, lbls := range known {
   276  			for _, lbl := range lbls {
   277  				overrides[overrideKey{
   278  					imp:  resolve.ImportSpec{Lang: impLang, Imp: imp},
   279  					lang: lang,
   280  				}] = lbl
   281  			}
   282  		}
   283  	}
   284  
   285  	if len(overrides) == 0 {
   286  		return
   287  	}
   288  
   289  	rewriteResolveConfigOverrides(getResolveConfig(c), overrides)
   290  }
   291  
   292  // ResolveImports is a utility function that returns a matching list of labels
   293  // for the given import list.
   294  func ResolveImports(resolver ImportResolver, lang, impLang string, imports []string) []label.Label {
   295  	deps := make([]label.Label, 0)
   296  	for _, imp := range imports {
   297  		result := resolver.Resolve(lang, impLang, imp)
   298  		if len(result) > 0 {
   299  			first := result[0]
   300  			deps = append(deps, first.Label)
   301  			// if r.options.Debug {
   302  			// 	r.options.Printf(lang, imp, "HIT", first.Label)
   303  			// } else {
   304  			// 	if r.options.Debug {
   305  			// 		r.options.Printf(lang, imp, "MISS", resolver)
   306  			// 	}
   307  		}
   308  	}
   309  	return deps
   310  }
   311  
   312  // ResolveImportsString is a utility function that returns a matching list of labels
   313  // for the given import list.
   314  func ResolveImportsString(resolver ImportResolver, rel, kind, impLang string, imports []string) []string {
   315  	match := ResolveImports(resolver, kind, impLang, imports)
   316  	deps := make([]string, len(match))
   317  	for i, l := range match {
   318  		deps[i] = l.Rel("", rel).String()
   319  	}
   320  	return deps
   321  }
   322  
   323  // getResolveConfig returns the resolve.resolveConfig
   324  func getResolveConfig(c *config.Config) interface{} {
   325  	return c.Exts["_resolve"]
   326  }
   327  
   328  // rewriteResolveConfigOverrides reads the existing private attribute and
   329  // appends more overrides.
   330  func rewriteResolveConfigOverrides(rc interface{}, more overrideSpec) {
   331  	rcv := reflect.ValueOf(rc).Elem()
   332  	val := reflect.Indirect(rcv)
   333  	member := val.FieldByName("overrides")
   334  	ptrToOverrides := unsafe.Pointer(member.UnsafeAddr())
   335  	overrides := (*overrideSpec)(ptrToOverrides)
   336  
   337  	// create new array: FindRuleWithOverride searches last entries first, so
   338  	// respect the users own resolve directives by putting them last
   339  	newOverrides := make(overrideSpec, 0)
   340  	for k, v := range more {
   341  		newOverrides[k] = v
   342  	}
   343  	for k, v := range *overrides {
   344  		newOverrides[k] = v
   345  	}
   346  
   347  	// reassign memory value
   348  	*overrides = newOverrides
   349  }
   350  
// overrideKey mirrors the private overrideKey type in bazel-gazelle's
// resolve/config.go. Maps keyed by it are reinterpreted through an unsafe
// pointer cast when overrides are installed, so the field order and field
// types must match gazelle's definition exactly. Do not reorder fields,
// change their types, or add fields.
type overrideKey struct {
	// imp is the import spec (impLang plus import string) being overridden.
	imp resolve.ImportSpec
	// lang is the resolving language, i.e. the first half of a langKey.
	lang string
}
   355  
// overrideSpec is a copy of the same private type in bazel-gazelle's
// resolve/config.go. It must be kept in sync with the original, because
// gazelle's field is reinterpreted as this type via unsafe. Any discrepancy
// in memory layout makes that cast unsound.
//
// NOTE: in https://github.com/bazelbuild/bazel-gazelle/pull/1687,
// []overrideSpec was changed to map[overrideKey]label.Label, so this copy
// presumably only matches gazelle versions that include that change.
type overrideSpec map[overrideKey]label.Label
   363  
   364  func langKey(lang, impLang string) string {
   365  	return lang + " " + impLang
   366  }
   367  
   368  func keyLang(key string) (string, string) {
   369  	parts := strings.SplitN(key, " ", 2)
   370  	return parts[0], parts[1]
   371  }