go.uber.org/yarpc@v1.72.1/internal/protoplugin-v2/registry.go (about) 1 // Copyright (c) 2022 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package protopluginv2 22 23 import ( 24 "fmt" 25 "path" 26 "path/filepath" 27 "strings" 28 29 "github.com/golang/protobuf/protoc-gen-go/descriptor" 30 "github.com/golang/protobuf/protoc-gen-go/plugin" 31 ) 32 33 type registry struct { 34 // msgs is a mapping from fully-qualified message name to descriptor 35 msgs map[string]*Message 36 // enums is a mapping from fully-qualified enum name to descriptor 37 enums map[string]*Enum 38 // files is a mapping from file path to descriptor 39 files map[string]*File 40 // prefix is a prefix to be inserted to golang package paths generated from proto package names. 41 prefix string 42 // pkgMap is a user-specified mapping from file path to proto package. 43 pkgMap map[string]string 44 // pkgAliases is a mapping from package aliases to package paths in go which are already taken. 45 pkgAliases map[string]string 46 } 47 48 func newRegistry() *registry { 49 return ®istry{ 50 msgs: make(map[string]*Message), 51 enums: make(map[string]*Enum), 52 files: make(map[string]*File), 53 pkgMap: make(map[string]string), 54 pkgAliases: make(map[string]string), 55 } 56 } 57 58 func (r *registry) Load(req *plugin_go.CodeGeneratorRequest) error { 59 for _, file := range req.GetProtoFile() { 60 r.loadFile(file) 61 } 62 var targetPkg string 63 for _, name := range req.FileToGenerate { 64 target := r.files[name] 65 if target == nil { 66 return fmt.Errorf("no such file: %s", name) 67 } 68 name := packageIdentityName(target.FileDescriptorProto) 69 if targetPkg == "" { 70 targetPkg = name 71 } else { 72 if targetPkg != name { 73 return fmt.Errorf("inconsistent package names: %s %s", targetPkg, name) 74 } 75 } 76 if err := r.loadServices(target); err != nil { 77 return err 78 } 79 if err := r.loadTransitiveFileDependencies(target); err != nil { 80 return err 81 } 82 } 83 return nil 84 } 85 86 func (r *registry) LookupMessage(location string, name string) (*Message, error) { 87 if strings.HasPrefix(name, ".") { 88 m, ok := r.msgs[name] 89 if !ok { 90 return nil, fmt.Errorf("no message found: %s", name) 91 } 92 return m, nil 93 } 94 95 if !strings.HasPrefix(location, ".") { 96 location = fmt.Sprintf(".%s", location) 97 } 98 components := strings.Split(location, ".") 99 for len(components) > 0 { 100 fqmn := strings.Join(append(components, name), ".") 101 if m, ok := r.msgs[fqmn]; ok { 102 return m, nil 103 } 104 components = components[:len(components)-1] 105 } 106 return nil, fmt.Errorf("no message found: %s", name) 107 } 108 109 func (r *registry) LookupFile(name string) (*File, error) { 110 f, ok := r.files[name] 111 if !ok { 112 return nil, fmt.Errorf("no such file given: %s", name) 113 } 114 return f, nil 115 } 116 117 func (r *registry) AddPackageMap(file, protoPackage string) { 118 r.pkgMap[file] = protoPackage 119 } 120 121 func (r *registry) SetPrefix(prefix string) { 122 r.prefix = prefix 123 } 124 125 func (r *registry) ReserveGoPackageAlias(alias, pkgpath string) error { 126 if taken, ok := r.pkgAliases[alias]; ok { 127 if taken == pkgpath { 128 return nil 129 } 130 return fmt.Errorf("package name %s is already taken. Use another alias", alias) 131 } 132 r.pkgAliases[alias] = pkgpath 133 return nil 134 } 135 136 // loadFile loads messages, enumerations and fields from "file". 137 // It does not loads services and methods in "file". You need to call 138 // loadServices after loadFiles is called for all files to load services and methods. 139 func (r *registry) loadFile(file *descriptor.FileDescriptorProto) { 140 pkg := &GoPackage{ 141 Path: r.goPackagePath(file), 142 Name: defaultGoPackageName(file), 143 } 144 if err := r.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil { 145 for i := 0; ; i++ { 146 alias := fmt.Sprintf("%s_%d", pkg.Name, i) 147 if err := r.ReserveGoPackageAlias(alias, pkg.Path); err == nil { 148 pkg.Alias = alias 149 break 150 } 151 } 152 } 153 f := &File{ 154 FileDescriptorProto: file, 155 GoPackage: pkg, 156 } 157 r.files[file.GetName()] = f 158 r.registerMsg(f, nil, file.GetMessageType()) 159 r.registerEnum(f, nil, file.GetEnumType()) 160 } 161 162 func (r *registry) registerMsg(file *File, outerPath []string, msgs []*descriptor.DescriptorProto) { 163 for i, md := range msgs { 164 m := &Message{ 165 DescriptorProto: md, 166 File: file, 167 Outers: outerPath, 168 Index: i, 169 } 170 for _, fd := range md.GetField() { 171 m.Fields = append(m.Fields, &Field{ 172 FieldDescriptorProto: fd, 173 Message: m, 174 }) 175 } 176 file.Messages = append(file.Messages, m) 177 r.msgs[m.FQMN()] = m 178 179 var outers []string 180 outers = append(outers, outerPath...) 181 outers = append(outers, m.GetName()) 182 r.registerMsg(file, outers, m.GetNestedType()) 183 r.registerEnum(file, outers, m.GetEnumType()) 184 } 185 } 186 187 func (r *registry) registerEnum(file *File, outerPath []string, enums []*descriptor.EnumDescriptorProto) { 188 for i, ed := range enums { 189 e := &Enum{ 190 EnumDescriptorProto: ed, 191 File: file, 192 Outers: outerPath, 193 Index: i, 194 } 195 file.Enums = append(file.Enums, e) 196 r.enums[e.FQEN()] = e 197 } 198 } 199 200 // goPackagePath returns the go package path which go files generated from "f" should have. 201 // It respects the mapping registered by AddPkgMap if exists. Or use go_package as import path 202 // if it includes a slash, Otherwide, it generates a path from the file name of "f". 203 func (r *registry) goPackagePath(f *descriptor.FileDescriptorProto) string { 204 name := f.GetName() 205 if pkg, ok := r.pkgMap[name]; ok { 206 return path.Join(r.prefix, pkg) 207 } 208 gopkg := f.Options.GetGoPackage() 209 idx := strings.LastIndex(gopkg, "/") 210 if idx >= 0 { 211 return gopkg 212 } 213 return path.Join(r.prefix, path.Dir(name)) 214 } 215 216 // loadServices registers services and their methods from "targetFile" to "r". 217 // It must be called after loadFile is called for all files so that loadServices 218 // can resolve names of message types and their fields. 219 func (r *registry) loadServices(file *File) error { 220 var svcs []*Service 221 for _, sd := range file.GetService() { 222 svc := &Service{ 223 ServiceDescriptorProto: sd, 224 File: file, 225 } 226 for _, md := range sd.GetMethod() { 227 meth, err := r.newMethod(svc, md) 228 if err != nil { 229 return err 230 } 231 svc.Methods = append(svc.Methods, meth) 232 } 233 if len(svc.Methods) == 0 { 234 continue 235 } 236 svcs = append(svcs, svc) 237 } 238 file.Services = svcs 239 return nil 240 } 241 242 func (r *registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto) (*Method, error) { 243 requestType, err := r.LookupMessage(svc.File.GetPackage(), md.GetInputType()) 244 if err != nil { 245 return nil, err 246 } 247 responseType, err := r.LookupMessage(svc.File.GetPackage(), md.GetOutputType()) 248 if err != nil { 249 return nil, err 250 } 251 return &Method{ 252 MethodDescriptorProto: md, 253 Service: svc, 254 RequestType: requestType, 255 ResponseType: responseType, 256 }, nil 257 } 258 259 // loadTransitiveFileDependencies registers services and their methods from "targetFile" to "r". 260 // It must be called after loadFile is called for all files so that loadTransitiveFileDependencies 261 // can resolve file descriptors as depdendencies. 262 func (r *registry) loadTransitiveFileDependencies(file *File) error { 263 seen := make(map[string]struct{}) 264 files, err := r.loadTransitiveFileDependenciesRecurse(file, seen) 265 if err != nil { 266 return err 267 } 268 file.TransitiveDependencies = files 269 return nil 270 } 271 272 func (r *registry) loadTransitiveFileDependenciesRecurse(file *File, seen map[string]struct{}) ([]*File, error) { 273 seen[file.GetName()] = struct{}{} 274 var deps []*File 275 for _, fname := range file.GetDependency() { 276 if _, ok := seen[fname]; ok { 277 continue 278 } 279 f, err := r.LookupFile(fname) 280 if err != nil { 281 return nil, err 282 } 283 deps = append(deps, f) 284 285 files, err := r.loadTransitiveFileDependenciesRecurse(f, seen) 286 if err != nil { 287 return nil, err 288 } 289 deps = append(deps, files...) 290 } 291 return deps, nil 292 } 293 294 // defaultGoPackageName returns the default go package name to be used for go files generated from "f". 295 // You might need to use an unique alias for the package when you import it. Use ReserveGoPackageAlias to get a unique alias. 296 func defaultGoPackageName(f *descriptor.FileDescriptorProto) string { 297 name := packageIdentityName(f) 298 return strings.Replace(name, ".", "_", -1) 299 } 300 301 // packageIdentityName returns the identity of packages. 302 // protoc-gen-grpc-gateway rejects CodeGenerationRequests which contains more than one packages 303 // as protoc-gen-go does. 304 func packageIdentityName(f *descriptor.FileDescriptorProto) string { 305 if f.Options != nil && f.Options.GoPackage != nil { 306 gopkg := f.Options.GetGoPackage() 307 // if go_package specifies an alias in the form of full/path/package;alias, use alias over package 308 idx := strings.Index(gopkg, ";") 309 if idx >= 0 { 310 return gopkg[idx+1:] 311 } 312 idx = strings.LastIndex(gopkg, "/") 313 if idx < 0 { 314 return gopkg 315 } 316 317 return gopkg[idx+1:] 318 } 319 320 if f.Package == nil { 321 base := filepath.Base(f.GetName()) 322 ext := filepath.Ext(base) 323 return strings.TrimSuffix(base, ext) 324 } 325 return f.GetPackage() 326 }