github.com/stackb/rules_proto@v0.0.0-20240221195024-5428336c51f1/pkg/language/protobuf/override.go (about)

     1  package protobuf
     2  
     3  import (
     4  	"log"
     5  
     6  	"github.com/bazelbuild/bazel-gazelle/config"
     7  	"github.com/bazelbuild/bazel-gazelle/label"
     8  	"github.com/bazelbuild/bazel-gazelle/rule"
     9  
    10  	"github.com/stackb/rules_proto/pkg/protoc"
    11  )
    12  
    13  const (
    14  	// protoOverrideRulesKey is used to stash a list of proto_library rules in a
    15  	// private attr for later deps resolution.
    16  	protoLibrariesRuleKey = "_proto_library_rules"
    17  	// overrideKindName is the name of the kind
    18  	overrideKindName = "proto_library_override"
    19  	// debugOverrides is a developer-flag.
    20  	debugOverrides = false
    21  )
    22  
    23  var overrideKind = rule.KindInfo{
    24  	ResolveAttrs: map[string]bool{"deps": true},
    25  }
    26  
    27  func makeProtoOverrideRule(libs []protoc.ProtoLibrary) *rule.Rule {
    28  	// This rule is *only* used to trigger a Resolve() callback such that we can
    29  	// process the proto_library rules we've captured here; the rule itself is
    30  	// always deleted from the file.
    31  	overrideRule := rule.NewRule(overrideKindName, protoLibrariesRuleKey)
    32  	overrideRule.SetPrivateAttr(protoLibrariesRuleKey, libs)
    33  	return overrideRule
    34  }
    35  
    36  func resolveOverrideRule(c *config.Config, rel string, overrideRule *rule.Rule, resolver protoc.ImportResolver) {
    37  
    38  	libs := overrideRule.PrivateAttr(protoLibrariesRuleKey).([]protoc.ProtoLibrary)
    39  	if len(libs) == 0 {
    40  		return
    41  	}
    42  
    43  	for _, lib := range libs {
    44  		r := lib.Rule()
    45  
    46  		// re-resolve dependencies.
    47  		deps := make([]label.Label, 0)
    48  
    49  		imports := r.PrivateAttr(config.GazelleImportsKey)
    50  		if imps, ok := imports.([]string); ok {
    51  			for _, imp := range imps {
    52  				result := resolver.Resolve("proto", "proto", imp)
    53  				if len(result) > 0 {
    54  					first := result[0]
    55  					deps = append(deps, first.Label)
    56  					if debugOverrides {
    57  						log.Println("go_googleapis resolve imports HIT", imp, first.Label)
    58  					}
    59  				} else {
    60  					if debugOverrides {
    61  						log.Printf("go_googleapis resolve imports MISS %s: %+v", imp, resolver)
    62  					}
    63  				}
    64  			}
    65  		}
    66  
    67  		if len(deps) > 0 {
    68  			ss := make([]string, len(deps))
    69  			for i, lbl := range deps {
    70  				ss[i] = lbl.Rel("", rel).String()
    71  			}
    72  			r.SetAttr("deps", protoc.DeduplicateAndSort(ss))
    73  		}
    74  	}
    75  
    76  	overrideRule.Delete()
    77  }