github.com/stackb/rules_proto@v0.0.0-20240221195024-5428336c51f1/pkg/plugin/golang/protobuf/protoc-gen-go.go (about)

     1  package protobuf
     2  
     3  import (
     4  	"container/list"
     5  	"path"
     6  	"strings"
     7  
     8  	"github.com/bazelbuild/bazel-gazelle/label"
     9  	"github.com/bazelbuild/bazel-gazelle/rule"
    10  
    11  	"github.com/stackb/rules_proto/pkg/protoc"
    12  )
    13  
    14  // TransitiveImportMappingsKey stores a map[string]string on the library
    15  const TransitiveImportMappingsKey = "_transitive_importmappings"
    16  
    17  const ProtocGenGoPluginName = "golang:protobuf:protoc-gen-go"
    18  
    19  func init() {
    20  	protoc.Plugins().MustRegisterPlugin(&ProtocGenGoPlugin{})
    21  }
    22  
// ProtocGenGoPlugin implements Plugin for the protoc-gen-go plugin.  It
// carries no state of its own.
type ProtocGenGoPlugin struct{}
    25  
    26  // Name implements part of the Plugin interface.
    27  func (p *ProtocGenGoPlugin) Name() string {
    28  	return ProtocGenGoPluginName
    29  }
    30  
    31  // Configure implements part of the Plugin interface.
    32  func (p *ProtocGenGoPlugin) Configure(ctx *protoc.PluginContext) *protoc.PluginConfiguration {
    33  	if !p.shouldApply(ctx.ProtoLibrary) {
    34  		return nil
    35  	}
    36  	mappings, _ := GetImportMappings(ctx.PluginConfig.GetOptions())
    37  
    38  	// record M associations now.
    39  	//
    40  	// TODO(pcj): where and when is the optimal time to do this?  protoc-gen-go,
    41  	// protoc-gen-gogo, and protoc-gen-go-grpc all use this.  Perhaps they
    42  	// should *all* perform it, just to be sure?
    43  	for k, v := range mappings {
    44  		// "option" is used as the name since we cannot leave that part of the
    45  		// label empty.
    46  		protoc.GlobalResolver().Provide("proto", "M", k, label.New("", v, "option")) // FIXME(pcj): should this not be config.RepoName?
    47  	}
    48  
    49  	return &protoc.PluginConfiguration{
    50  		Label:   label.New("build_stack_rules_proto", "plugin/golang/protobuf", "protoc-gen-go"),
    51  		Outputs: p.outputs(ctx.ProtoLibrary, mappings),
    52  		Options: ctx.PluginConfig.GetOptions(),
    53  	}
    54  }
    55  
    56  func (p *ProtocGenGoPlugin) ResolvePluginOptions(cfg *protoc.PluginConfiguration, r *rule.Rule, from label.Label) []string {
    57  	return ResolvePluginOptionsTransitive(cfg, r, from)
    58  }
    59  
    60  func ResolvePluginOptionsTransitive(cfg *protoc.PluginConfiguration, r *rule.Rule, from label.Label) []string {
    61  	transitiveMappings := ResolveTransitiveImportMappings(r, from)
    62  
    63  	options := make([]string, 0)
    64  
    65  	for _, opt := range cfg.Options {
    66  		if !strings.HasPrefix(opt, "M") {
    67  			options = append(options, opt)
    68  			continue
    69  		}
    70  
    71  		parts := strings.SplitN(opt[1:], "=", 2)
    72  		if len(parts) != 2 {
    73  			options = append(options, opt)
    74  			continue
    75  		}
    76  
    77  		imp := parts[0]
    78  		if _, ok := transitiveMappings[imp]; ok {
    79  			options = append(options, opt)
    80  			continue
    81  		}
    82  
    83  		// if we get here, the M option is not in the set of transitives for
    84  		// this rule, so leave it out.
    85  	}
    86  
    87  	return options
    88  }
    89  
    90  func ResolveTransitiveImportMappings(r *rule.Rule, from label.Label) map[string]string {
    91  	lib := r.PrivateAttr(protoc.ProtoLibraryKey)
    92  	if lib == nil {
    93  		return nil
    94  	}
    95  	library := lib.(protoc.ProtoLibrary)
    96  	libRule := library.Rule()
    97  
    98  	// already created?
    99  	if transitiveMappings, ok := libRule.PrivateAttr(TransitiveImportMappingsKey).(map[string]string); ok {
   100  		return transitiveMappings
   101  	}
   102  
   103  	// nope.
   104  	transitiveMappings := make(map[string]string)
   105  	resolver := protoc.GlobalResolver()
   106  
   107  	seen := make(map[string]bool)
   108  	stack := list.New()
   109  	for _, src := range library.Srcs() {
   110  		stack.PushBack(path.Join(from.Pkg, src))
   111  	}
   112  
   113  	// for every source file in the proto library, gather the list of source
   114  	// files on which it depends, until there are no more unprocessed sources.
   115  	// Foreach one check if there is an importmapping for it and record the
   116  	// association.
   117  	for {
   118  		if stack.Len() == 0 {
   119  			break
   120  		}
   121  		current := stack.Front()
   122  		stack.Remove(current)
   123  
   124  		protofile := current.Value.(string)
   125  		if seen[protofile] {
   126  			continue
   127  		}
   128  		seen[protofile] = true
   129  
   130  		depends := resolver.Resolve("proto", "depends", protofile)
   131  		for _, dep := range depends {
   132  			stack.PushBack(path.Join(dep.Label.Pkg, dep.Label.Name))
   133  		}
   134  
   135  		mappings := resolver.Resolve("proto", "M", protofile)
   136  		if len(mappings) > 0 {
   137  			first := mappings[0]
   138  			transitiveMappings[protofile] = path.Join(first.Label.Pkg)
   139  		}
   140  	}
   141  
   142  	libRule.SetPrivateAttr(TransitiveImportMappingsKey, transitiveMappings)
   143  
   144  	return transitiveMappings
   145  }
   146  
   147  func (p *ProtocGenGoPlugin) shouldApply(lib protoc.ProtoLibrary) bool {
   148  	for _, f := range lib.Files() {
   149  		if f.HasMessages() || f.HasEnums() {
   150  			return true
   151  		}
   152  	}
   153  	return false
   154  }
   155  
   156  func (p *ProtocGenGoPlugin) outputs(lib protoc.ProtoLibrary, importMappings map[string]string) []string {
   157  	srcs := make([]string, 0)
   158  	for _, f := range lib.Files() {
   159  		if !(f.HasMessages() || f.HasEnums()) {
   160  			continue
   161  		}
   162  		srcs = append(srcs, GetGoOutputBaseName(f, importMappings)+".pb.go")
   163  	}
   164  	return srcs
   165  }
   166  
   167  func GetGoOutputBaseName(f *protoc.File, importMappings map[string]string) string {
   168  	base := f.Name
   169  	pkg := f.Package()
   170  	// see https://github.com/gogo/protobuf/blob/master/protoc-gen-gogo/generator/generator.go#L347
   171  	if mapping := importMappings[path.Join(f.Dir, f.Basename)]; mapping != "" {
   172  		base = path.Join(mapping, base)
   173  	} else if goPackage, _, ok := protoc.GoPackageOption(f.Options()); ok {
   174  		base = path.Join(goPackage, base)
   175  	} else if pkg.Name != "" {
   176  		base = path.Join(strings.ReplaceAll(pkg.Name, ".", "/"), base)
   177  	}
   178  	return base
   179  }
   180  
   181  func GetImportMappings(options []string) (map[string]string, []string) {
   182  	// gather options that look like protoc-gen-go "importmapping" (M) options
   183  	// (e.g Mfoo.proto=github.com/example/foo).
   184  	mappings := make(map[string]string)
   185  	rest := make([]string, 0)
   186  
   187  	for _, opt := range options {
   188  		if !strings.HasPrefix(opt, "M") {
   189  			rest = append(rest, opt)
   190  			continue
   191  		}
   192  		parts := strings.SplitN(opt[1:], "=", 2)
   193  		if len(parts) != 2 {
   194  			rest = append(rest, opt)
   195  			continue
   196  		}
   197  		mappings[parts[0]] = parts[1]
   198  	}
   199  
   200  	return mappings, rest
   201  }