github.com/Big-big-orange/protoreflect@v0.0.0-20240408141420-285cedfdf6a4/desc/builder/resolver.go (about) 1 package builder 2 3 import ( 4 "fmt" 5 "reflect" 6 "strings" 7 8 "google.golang.org/protobuf/encoding/protowire" 9 "google.golang.org/protobuf/proto" 10 "google.golang.org/protobuf/reflect/protoreflect" 11 "google.golang.org/protobuf/reflect/protoregistry" 12 "google.golang.org/protobuf/types/descriptorpb" 13 14 "github.com/Big-big-orange/protoreflect/desc" 15 "github.com/Big-big-orange/protoreflect/desc/internal" 16 "github.com/Big-big-orange/protoreflect/dynamic" 17 ) 18 19 type dependencies struct { 20 descs map[*desc.FileDescriptor]struct{} 21 res protoregistry.Types 22 } 23 24 func newDependencies() *dependencies { 25 return &dependencies{ 26 descs: map[*desc.FileDescriptor]struct{}{}, 27 } 28 } 29 30 func (d *dependencies) add(fd *desc.FileDescriptor) { 31 if _, ok := d.descs[fd]; ok { 32 // already added 33 return 34 } 35 d.descs[fd] = struct{}{} 36 internal.RegisterExtensionsFromImportedFile(&d.res, fd.UnwrapFile()) 37 } 38 39 // dependencyResolver is the work-horse for converting a tree of builders into a 40 // tree of descriptors. It scans a root (usually a file builder) and recursively 41 // resolves all dependencies (references to builders in other trees as well as 42 // references to other already-built descriptors). The result of resolution is a 43 // file descriptor (or an error). 44 type dependencyResolver struct { 45 resolvedRoots map[Builder]*desc.FileDescriptor 46 seen map[Builder]struct{} 47 opts BuilderOptions 48 } 49 50 func newResolver(opts BuilderOptions) *dependencyResolver { 51 return &dependencyResolver{ 52 resolvedRoots: map[Builder]*desc.FileDescriptor{}, 53 seen: map[Builder]struct{}{}, 54 opts: opts, 55 } 56 } 57 58 func (r *dependencyResolver) resolveElement(b Builder, seen []Builder) (*desc.FileDescriptor, error) { 59 b = getRoot(b) 60 61 if fd, ok := r.resolvedRoots[b]; ok { 62 return fd, nil 63 } 64 65 for _, s := range seen { 66 if s == b { 67 names := make([]string, len(seen)+1) 68 for i, s := range seen { 69 names[i] = s.GetName() 70 } 71 names[len(seen)] = b.GetName() 72 return nil, fmt.Errorf("descriptors have cyclic dependency: %s", strings.Join(names, " -> ")) 73 } 74 } 75 seen = append(seen, b) 76 77 var fd *desc.FileDescriptor 78 var err error 79 switch b := b.(type) { 80 case *FileBuilder: 81 fd, err = r.resolveFile(b, b, seen) 82 default: 83 fd, err = r.resolveSyntheticFile(b, seen) 84 } 85 if err != nil { 86 return nil, err 87 } 88 r.resolvedRoots[b] = fd 89 return fd, nil 90 } 91 92 func (r *dependencyResolver) resolveFile(fb *FileBuilder, root Builder, seen []Builder) (*desc.FileDescriptor, error) { 93 deps := newDependencies() 94 // add explicit imports first 95 for fd := range fb.explicitImports { 96 deps.add(fd) 97 } 98 for dep := range fb.explicitDeps { 99 if dep == fb { 100 // ignore erroneous self references 101 continue 102 } 103 fd, err := r.resolveElement(dep, seen) 104 if err != nil { 105 return nil, err 106 } 107 deps.add(fd) 108 } 109 // now accumulate implicit dependencies based on other types referenced 110 for _, mb := range fb.messages { 111 if err := r.resolveTypesInMessage(root, seen, deps, mb); err != nil { 112 return nil, err 113 } 114 } 115 for _, exb := range fb.extensions { 116 if err := r.resolveTypesInExtension(root, seen, deps, exb); err != nil { 117 return nil, err 118 } 119 } 120 for _, sb := range fb.services { 121 if err := r.resolveTypesInService(root, seen, deps, sb); err != nil { 122 return nil, err 123 } 124 } 125 126 // finally, resolve custom options (which may refer to deps already 127 // computed above) 128 if err := r.resolveTypesInFileOptions(root, deps, fb); err != nil { 129 return nil, err 130 } 131 132 depSlice := make([]*desc.FileDescriptor, 0, len(deps.descs)) 133 depMap := make(map[string]*desc.FileDescriptor, len(deps.descs)) 134 for dep := range deps.descs { 135 isDuplicate, err := isDuplicateDependency(dep, depMap) 136 if err != nil { 137 return nil, err 138 } 139 if !isDuplicate { 140 depSlice = append(depSlice, dep) 141 } 142 } 143 144 fp, err := fb.buildProto(depSlice) 145 if err != nil { 146 return nil, err 147 } 148 149 // make sure this file name doesn't collide with any of its dependencies 150 fileNames := map[string]struct{}{} 151 for _, d := range depSlice { 152 addFileNames(d, fileNames) 153 } 154 unique := makeUnique(fp.GetName(), fileNames) 155 if unique != fp.GetName() { 156 fp.Name = proto.String(unique) 157 } 158 159 return desc.CreateFileDescriptor(fp, depSlice...) 160 } 161 162 // isDuplicateDependency checks for duplicate descriptors 163 func isDuplicateDependency(dep *desc.FileDescriptor, depMap map[string]*desc.FileDescriptor) (bool, error) { 164 if _, exists := depMap[dep.GetName()]; !exists { 165 depMap[dep.GetName()] = dep 166 return false, nil 167 } 168 prevFDP := depMap[dep.GetName()].AsFileDescriptorProto() 169 depFDP := dep.AsFileDescriptorProto() 170 171 // temporarily reset Source Code Fields as builders does not have SourceCodeInfo 172 defer setSourceCodeInfo(prevFDP, nil)() 173 defer setSourceCodeInfo(depFDP, nil)() 174 175 if !proto.Equal(prevFDP, depFDP) { 176 return true, fmt.Errorf("multiple versions of descriptors found with same file name: %s", dep.GetName()) 177 } 178 return true, nil 179 } 180 181 func setSourceCodeInfo(fdp *descriptorpb.FileDescriptorProto, info *descriptorpb.SourceCodeInfo) (reset func()) { 182 prevSourceCodeInfo := fdp.SourceCodeInfo 183 fdp.SourceCodeInfo = info 184 return func() { fdp.SourceCodeInfo = prevSourceCodeInfo } 185 } 186 187 func addFileNames(fd *desc.FileDescriptor, files map[string]struct{}) { 188 if _, ok := files[fd.GetName()]; ok { 189 // already added 190 return 191 } 192 files[fd.GetName()] = struct{}{} 193 for _, d := range fd.GetDependencies() { 194 addFileNames(d, files) 195 } 196 } 197 198 func (r *dependencyResolver) resolveSyntheticFile(b Builder, seen []Builder) (*desc.FileDescriptor, error) { 199 // find ancestor to temporarily attach to new file 200 curr := b 201 for curr.GetParent() != nil { 202 curr = curr.GetParent() 203 } 204 f := NewFile("") 205 switch curr := curr.(type) { 206 case *MessageBuilder: 207 f.messages = append(f.messages, curr) 208 case *EnumBuilder: 209 f.enums = append(f.enums, curr) 210 case *ServiceBuilder: 211 f.services = append(f.services, curr) 212 case *FieldBuilder: 213 if curr.IsExtension() { 214 f.extensions = append(f.extensions, curr) 215 } else { 216 panic("field must be added to message before calling Build()") 217 } 218 case *OneOfBuilder: 219 if _, ok := b.(*OneOfBuilder); ok { 220 panic("one-of must be added to message before calling Build()") 221 } else { 222 // b was a child of one-of which means it must have been a field 223 panic("field must be added to message before calling Build()") 224 } 225 case *MethodBuilder: 226 panic("method must be added to service before calling Build()") 227 case *EnumValueBuilder: 228 panic("enum value must be added to enum before calling Build()") 229 default: 230 panic(fmt.Sprintf("Unrecognized kind of builder: %T", b)) 231 } 232 curr.setParent(f) 233 234 // don't forget to reset when done 235 defer func() { 236 curr.setParent(nil) 237 }() 238 239 return r.resolveFile(f, b, seen) 240 } 241 242 func (r *dependencyResolver) resolveTypesInMessage(root Builder, seen []Builder, deps *dependencies, mb *MessageBuilder) error { 243 for _, b := range mb.fieldsAndOneOfs { 244 if flb, ok := b.(*FieldBuilder); ok { 245 if err := r.resolveTypesInField(root, seen, deps, flb); err != nil { 246 return err 247 } 248 } else { 249 oob := b.(*OneOfBuilder) 250 for _, flb := range oob.choices { 251 if err := r.resolveTypesInField(root, seen, deps, flb); err != nil { 252 return err 253 } 254 } 255 } 256 } 257 for _, nmb := range mb.nestedMessages { 258 if err := r.resolveTypesInMessage(root, seen, deps, nmb); err != nil { 259 return err 260 } 261 } 262 for _, exb := range mb.nestedExtensions { 263 if err := r.resolveTypesInExtension(root, seen, deps, exb); err != nil { 264 return err 265 } 266 } 267 return nil 268 } 269 270 func (r *dependencyResolver) resolveTypesInExtension(root Builder, seen []Builder, deps *dependencies, exb *FieldBuilder) error { 271 if err := r.resolveTypesInField(root, seen, deps, exb); err != nil { 272 return err 273 } 274 if exb.foreignExtendee != nil { 275 deps.add(exb.foreignExtendee.GetFile()) 276 } else if err := r.resolveType(root, seen, exb.localExtendee, deps); err != nil { 277 return err 278 } 279 return nil 280 } 281 282 func (r *dependencyResolver) resolveTypesInService(root Builder, seen []Builder, deps *dependencies, sb *ServiceBuilder) error { 283 for _, mtb := range sb.methods { 284 if err := r.resolveRpcType(root, seen, mtb.ReqType, deps); err != nil { 285 return err 286 } 287 if err := r.resolveRpcType(root, seen, mtb.RespType, deps); err != nil { 288 return err 289 } 290 } 291 return nil 292 } 293 294 func (r *dependencyResolver) resolveRpcType(root Builder, seen []Builder, t *RpcType, deps *dependencies) error { 295 if t.foreignType != nil { 296 deps.add(t.foreignType.GetFile()) 297 } else { 298 return r.resolveType(root, seen, t.localType, deps) 299 } 300 return nil 301 } 302 303 func (r *dependencyResolver) resolveTypesInField(root Builder, seen []Builder, deps *dependencies, flb *FieldBuilder) error { 304 if flb.fieldType.foreignMsgType != nil { 305 deps.add(flb.fieldType.foreignMsgType.GetFile()) 306 } else if flb.fieldType.foreignEnumType != nil { 307 deps.add(flb.fieldType.foreignEnumType.GetFile()) 308 } else if flb.fieldType.localMsgType != nil { 309 if flb.fieldType.localMsgType == flb.msgType { 310 return r.resolveTypesInMessage(root, seen, deps, flb.msgType) 311 } else { 312 return r.resolveType(root, seen, flb.fieldType.localMsgType, deps) 313 } 314 } else if flb.fieldType.localEnumType != nil { 315 return r.resolveType(root, seen, flb.fieldType.localEnumType, deps) 316 } 317 return nil 318 } 319 320 func (r *dependencyResolver) resolveType(root Builder, seen []Builder, typeBuilder Builder, deps *dependencies) error { 321 otherRoot := getRoot(typeBuilder) 322 if root == otherRoot { 323 // local reference, so it will get resolved when we finish resolving this root 324 return nil 325 } 326 fd, err := r.resolveElement(otherRoot, seen) 327 if err != nil { 328 return err 329 } 330 deps.add(fd) 331 return nil 332 } 333 334 func (r *dependencyResolver) resolveTypesInFileOptions(root Builder, deps *dependencies, fb *FileBuilder) error { 335 for _, mb := range fb.messages { 336 if err := r.resolveTypesInMessageOptions(root, fb.origExts, deps, mb); err != nil { 337 return err 338 } 339 } 340 for _, eb := range fb.enums { 341 if err := r.resolveTypesInEnumOptions(root, fb.origExts, deps, eb); err != nil { 342 return err 343 } 344 } 345 for _, exb := range fb.extensions { 346 if err := r.resolveTypesInOptions(root, fb.origExts, deps, exb.Options); err != nil { 347 return err 348 } 349 } 350 for _, sb := range fb.services { 351 for _, mtb := range sb.methods { 352 if err := r.resolveTypesInOptions(root, fb.origExts, deps, mtb.Options); err != nil { 353 return err 354 } 355 } 356 if err := r.resolveTypesInOptions(root, fb.origExts, deps, sb.Options); err != nil { 357 return err 358 } 359 } 360 return r.resolveTypesInOptions(root, fb.origExts, deps, fb.Options) 361 } 362 363 func (r *dependencyResolver) resolveTypesInMessageOptions(root Builder, fileExts *dynamic.ExtensionRegistry, deps *dependencies, mb *MessageBuilder) error { 364 for _, b := range mb.fieldsAndOneOfs { 365 if flb, ok := b.(*FieldBuilder); ok { 366 if err := r.resolveTypesInOptions(root, fileExts, deps, flb.Options); err != nil { 367 return err 368 } 369 } else { 370 oob := b.(*OneOfBuilder) 371 for _, flb := range oob.choices { 372 if err := r.resolveTypesInOptions(root, fileExts, deps, flb.Options); err != nil { 373 return err 374 } 375 } 376 if err := r.resolveTypesInOptions(root, fileExts, deps, oob.Options); err != nil { 377 return err 378 } 379 } 380 } 381 for _, extr := range mb.ExtensionRanges { 382 if err := r.resolveTypesInOptions(root, fileExts, deps, extr.Options); err != nil { 383 return err 384 } 385 } 386 for _, eb := range mb.nestedEnums { 387 if err := r.resolveTypesInEnumOptions(root, fileExts, deps, eb); err != nil { 388 return err 389 } 390 } 391 for _, nmb := range mb.nestedMessages { 392 if err := r.resolveTypesInMessageOptions(root, fileExts, deps, nmb); err != nil { 393 return err 394 } 395 } 396 for _, exb := range mb.nestedExtensions { 397 if err := r.resolveTypesInOptions(root, fileExts, deps, exb.Options); err != nil { 398 return err 399 } 400 } 401 if err := r.resolveTypesInOptions(root, fileExts, deps, mb.Options); err != nil { 402 return err 403 } 404 return nil 405 } 406 407 func (r *dependencyResolver) resolveTypesInEnumOptions(root Builder, fileExts *dynamic.ExtensionRegistry, deps *dependencies, eb *EnumBuilder) error { 408 for _, evb := range eb.values { 409 if err := r.resolveTypesInOptions(root, fileExts, deps, evb.Options); err != nil { 410 return err 411 } 412 } 413 if err := r.resolveTypesInOptions(root, fileExts, deps, eb.Options); err != nil { 414 return err 415 } 416 return nil 417 } 418 419 func (r *dependencyResolver) resolveTypesInOptions(root Builder, fileExts *dynamic.ExtensionRegistry, deps *dependencies, opts proto.Message) error { 420 // nothing to see if opts is nil 421 if opts == nil { 422 return nil 423 } 424 if rv := reflect.ValueOf(opts); rv.Kind() == reflect.Ptr && rv.IsNil() { 425 return nil 426 } 427 428 ref := opts.ProtoReflect() 429 tags := map[int32]protoreflect.ExtensionType{} 430 proto.RangeExtensions(opts, func(xt protoreflect.ExtensionType, _ interface{}) bool { 431 num := int32(xt.TypeDescriptor().Number()) 432 tags[num] = xt 433 return true 434 }) 435 436 unk := ref.GetUnknown() 437 for len(unk) > 0 { 438 v, n := protowire.ConsumeVarint(unk) 439 if n < 0 { 440 break 441 } 442 unk = unk[n:] 443 444 num, t := protowire.DecodeTag(v) 445 if _, ok := tags[int32(num)]; !ok { 446 tags[int32(num)] = nil 447 } 448 449 switch t { 450 case protowire.VarintType: 451 _, n = protowire.ConsumeVarint(unk) 452 case protowire.Fixed64Type: 453 _, n = protowire.ConsumeFixed64(unk) 454 case protowire.BytesType: 455 _, n = protowire.ConsumeBytes(unk) 456 case protowire.StartGroupType: 457 _, n = protowire.ConsumeGroup(num, unk) 458 case protowire.EndGroupType: 459 // invalid encoding 460 break 461 case protowire.Fixed32Type: 462 _, n = protowire.ConsumeFixed32(unk) 463 } 464 if n < 0 { 465 break 466 } 467 unk = unk[n:] 468 } 469 470 msgName := string(proto.MessageName(opts)) 471 for tag, xt := range tags { 472 // see if known dependencies have this option 473 if _, err := deps.res.FindExtensionByNumber(protoreflect.FullName(msgName), protoreflect.FieldNumber(tag)); err == nil { 474 // yep! nothing else to do 475 continue 476 } 477 // see if this extension is defined in *this* builder 478 if findExtension(root, msgName, tag) { 479 // yep! 480 continue 481 } 482 // see if configured extension registry knows about it 483 if extd := r.opts.Extensions.FindExtension(msgName, tag); extd != nil { 484 // extension registry recognized it! 485 deps.add(extd.GetFile()) 486 continue 487 } 488 // see if given file extensions knows about it 489 if fileExts != nil { 490 extd := fileExts.FindExtension(msgName, tag) 491 if extd != nil { 492 // file extensions recognized it! 493 deps.add(extd.GetFile()) 494 continue 495 } 496 } 497 498 if xt != nil { 499 // known extension? add its file to builder's deps 500 fd, err := desc.WrapFile(xt.TypeDescriptor().ParentFile()) 501 if err != nil { 502 return err 503 } 504 deps.add(fd) 505 continue 506 } 507 508 if r.opts.RequireInterpretedOptions { 509 // we require options to be interpreted but are not able to! 510 return fmt.Errorf("could not interpret custom option for %s, tag %d", msgName, tag) 511 } 512 } 513 return nil 514 } 515 516 func findExtension(b Builder, messageName string, extTag int32) bool { 517 if fb, ok := b.(*FileBuilder); ok && findExtensionInFile(fb, messageName, extTag) { 518 return true 519 } 520 if mb, ok := b.(*MessageBuilder); ok && findExtensionInMessage(mb, messageName, extTag) { 521 return true 522 } 523 return false 524 } 525 526 func findExtensionInFile(fb *FileBuilder, messageName string, extTag int32) bool { 527 for _, extb := range fb.extensions { 528 if extb.GetExtendeeTypeName() == messageName && extb.number == extTag { 529 return true 530 } 531 } 532 for _, mb := range fb.messages { 533 if findExtensionInMessage(mb, messageName, extTag) { 534 return true 535 } 536 } 537 return false 538 } 539 540 func findExtensionInMessage(mb *MessageBuilder, messageName string, extTag int32) bool { 541 for _, extb := range mb.nestedExtensions { 542 if extb.GetExtendeeTypeName() == messageName && extb.number == extTag { 543 return true 544 } 545 } 546 for _, mb := range mb.nestedMessages { 547 if findExtensionInMessage(mb, messageName, extTag) { 548 return true 549 } 550 } 551 return false 552 }