github.com/jhump/protoreflect@v1.16.0/dynamic/msgregistry/ptype_resolver.go (about) 1 package msgregistry 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "reflect" 8 "sort" 9 "strings" 10 "sync" 11 "sync/atomic" 12 13 "github.com/golang/protobuf/proto" 14 "google.golang.org/protobuf/types/descriptorpb" 15 "google.golang.org/protobuf/types/known/apipb" 16 "google.golang.org/protobuf/types/known/typepb" 17 "google.golang.org/protobuf/types/known/wrapperspb" 18 19 "github.com/jhump/protoreflect/desc" 20 "github.com/jhump/protoreflect/dynamic" 21 ) 22 23 var ( 24 enumOptionsDesc, enumValueOptionsDesc *desc.MessageDescriptor 25 msgOptionsDesc, fieldOptionsDesc *desc.MessageDescriptor 26 svcOptionsDesc, methodOptionsDesc *desc.MessageDescriptor 27 ) 28 29 func init() { 30 var err error 31 enumOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.EnumOptions)(nil)) 32 if err != nil { 33 panic("Failed to load descriptor for EnumOptions") 34 } 35 enumValueOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.EnumValueOptions)(nil)) 36 if err != nil { 37 panic("Failed to load descriptor for EnumValueOptions") 38 } 39 msgOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.MessageOptions)(nil)) 40 if err != nil { 41 panic("Failed to load descriptor for MessageOptions") 42 } 43 fieldOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.FieldOptions)(nil)) 44 if err != nil { 45 panic("Failed to load descriptor for FieldOptions") 46 } 47 svcOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.ServiceOptions)(nil)) 48 if err != nil { 49 panic("Failed to load descriptor for ServiceOptions") 50 } 51 methodOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptorpb.MethodOptions)(nil)) 52 if err != nil { 53 panic("Failed to load descriptor for MethodOptions") 54 } 55 } 56 57 func ensureScheme(url string) string { 58 pos := strings.Index(url, "://") 59 if pos < 0 { 60 return "https://" + url 61 } 62 return url 63 } 64 65 // typeResolver is used by MessageRegistry to resolve message types. It uses a given TypeFetcher 66 // to retrieve type definitions and caches resulting descriptor objects. 67 type typeResolver struct { 68 fetcher TypeFetcher 69 mr *MessageRegistry 70 mu sync.RWMutex 71 cache map[string]desc.Descriptor 72 } 73 74 // resolveUrlToMessageDescriptor returns a message descriptor that represents the type at the given URL. 75 func (r *typeResolver) resolveUrlToMessageDescriptor(url string) (*desc.MessageDescriptor, error) { 76 url = ensureScheme(url) 77 r.mu.RLock() 78 cached := r.cache[url] 79 r.mu.RUnlock() 80 if cached != nil { 81 if md, ok := cached.(*desc.MessageDescriptor); ok { 82 return md, nil 83 } else { 84 return nil, fmt.Errorf("type for URL %v is the wrong type: wanted message, got enum", url) 85 } 86 } 87 88 rc := newResolutionContext(r) 89 if err := rc.addType(url, false); err != nil { 90 return nil, err 91 } 92 93 var files map[string]*desc.FileDescriptor 94 files, err := rc.toFileDescriptors(r.mr) 95 if err != nil { 96 return nil, err 97 } 98 r.mu.Lock() 99 defer r.mu.Unlock() 100 var md *desc.MessageDescriptor 101 if len(rc.typeLocations) > 0 { 102 if r.cache == nil { 103 r.cache = map[string]desc.Descriptor{} 104 } 105 } 106 for typeUrl, fileName := range rc.typeLocations { 107 fd := files[fileName] 108 sym := fd.FindSymbol(typeName(typeUrl)) 109 r.cache[typeUrl] = sym 110 if url == typeUrl { 111 md = sym.(*desc.MessageDescriptor) 112 } 113 } 114 return md, nil 115 } 116 117 // resolveUrlsToMessageDescriptors returns a map of the given URLs to corresponding 118 // message descriptors that represent the types at those URLs. 119 func (r *typeResolver) resolveUrlsToMessageDescriptors(urls ...string) (map[string]*desc.MessageDescriptor, error) { 120 ret := map[string]*desc.MessageDescriptor{} 121 var unresolved []string 122 r.mu.RLock() 123 for _, u := range urls { 124 u = ensureScheme(u) 125 cached := r.cache[u] 126 if cached != nil { 127 if md, ok := cached.(*desc.MessageDescriptor); ok { 128 ret[u] = md 129 } else { 130 r.mu.RUnlock() 131 return nil, fmt.Errorf("type for URL %v is the wrong type: wanted message, got enum", u) 132 } 133 } else { 134 ret[u] = nil 135 unresolved = append(unresolved, u) 136 } 137 } 138 r.mu.RUnlock() 139 140 if len(unresolved) == 0 { 141 return ret, nil 142 } 143 144 rc := newResolutionContext(r) 145 for _, u := range unresolved { 146 if err := rc.addType(u, false); err != nil { 147 return nil, err 148 } 149 } 150 151 var files map[string]*desc.FileDescriptor 152 files, err := rc.toFileDescriptors(r.mr) 153 if err != nil { 154 return nil, err 155 } 156 r.mu.Lock() 157 defer r.mu.Unlock() 158 if len(rc.typeLocations) > 0 { 159 if r.cache == nil { 160 r.cache = map[string]desc.Descriptor{} 161 } 162 } 163 for typeUrl, fileName := range rc.typeLocations { 164 fd := files[fileName] 165 sym := fd.FindSymbol(typeName(typeUrl)) 166 r.cache[typeUrl] = sym 167 if _, ok := ret[typeUrl]; ok { 168 ret[typeUrl] = sym.(*desc.MessageDescriptor) 169 } 170 } 171 return ret, nil 172 } 173 174 // resolveUrlToEnumDescriptor returns an enum descriptor that represents the enum type at the given URL. 175 func (r *typeResolver) resolveUrlToEnumDescriptor(url string) (*desc.EnumDescriptor, error) { 176 url = ensureScheme(url) 177 r.mu.RLock() 178 cached := r.cache[url] 179 r.mu.RUnlock() 180 if cached != nil { 181 if ed, ok := cached.(*desc.EnumDescriptor); ok { 182 return ed, nil 183 } else { 184 return nil, fmt.Errorf("type for URL %v is the wrong type: wanted enum, got message", url) 185 } 186 } 187 188 rc := newResolutionContext(r) 189 if err := rc.addType(url, true); err != nil { 190 return nil, err 191 } 192 193 var files map[string]*desc.FileDescriptor 194 files, err := rc.toFileDescriptors(r.mr) 195 if err != nil { 196 return nil, err 197 } 198 r.mu.Lock() 199 defer r.mu.Unlock() 200 var ed *desc.EnumDescriptor 201 if len(rc.typeLocations) > 0 { 202 if r.cache == nil { 203 r.cache = map[string]desc.Descriptor{} 204 } 205 } 206 for typeUrl, fileName := range rc.typeLocations { 207 fd := files[fileName] 208 sym := fd.FindSymbol(typeName(typeUrl)) 209 r.cache[typeUrl] = sym 210 if url == typeUrl { 211 ed = sym.(*desc.EnumDescriptor) 212 } 213 } 214 return ed, nil 215 } 216 217 type tracker func(d desc.Descriptor) bool 218 219 func newNameTracker() tracker { 220 names := map[string]struct{}{} 221 return func(d desc.Descriptor) bool { 222 name := d.GetFullyQualifiedName() 223 if _, ok := names[name]; ok { 224 return false 225 } 226 names[name] = struct{}{} 227 return true 228 } 229 } 230 231 func addDescriptors(ref string, files map[string]*fileEntry, d desc.Descriptor, msgs map[string]*desc.MessageDescriptor, onAdd tracker) { 232 name := d.GetFullyQualifiedName() 233 234 fileName := d.GetFile().GetName() 235 if fileName != ref { 236 dependee := files[ref] 237 if dependee.deps == nil { 238 dependee.deps = map[string]struct{}{} 239 } 240 dependee.deps[fileName] = struct{}{} 241 } 242 243 if !onAdd(d) { 244 // already added this one 245 return 246 } 247 248 fe := files[fileName] 249 if fe == nil { 250 fe = &fileEntry{} 251 fe.proto3 = d.GetFile().IsProto3() 252 files[fileName] = fe 253 } 254 fe.types.addType(name, d.AsProto()) 255 256 if md, ok := d.(*desc.MessageDescriptor); ok { 257 for _, fld := range md.GetFields() { 258 if fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE || fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_GROUP { 259 // prefer descriptor in msgs map over what the field descriptor indicates 260 md := msgs[fld.GetMessageType().GetFullyQualifiedName()] 261 if md == nil { 262 md = fld.GetMessageType() 263 } 264 addDescriptors(fileName, files, md, msgs, onAdd) 265 } else if fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_ENUM { 266 addDescriptors(fileName, files, fld.GetEnumType(), msgs, onAdd) 267 } 268 } 269 } 270 } 271 272 // resolutionContext provides the state for a resolution operation, accumulating details about 273 // type descriptions and the files that contain them. 274 type resolutionContext struct { 275 // The context and cancel function, used to coordinate multiple goroutines when there are multiple 276 // type or enum descriptions to download. 277 ctx context.Context 278 cancel func() 279 res *typeResolver 280 281 mu sync.Mutex 282 // map of file names to details regarding the files' contents 283 files map[string]*fileEntry 284 // map of type URLs to the file name that defines them 285 typeLocations map[string]string 286 // count of source contexts that do not indicate a file name (used to generate unique file names 287 // when synthesizing file descriptors) 288 unknownCount int 289 } 290 291 func newResolutionContext(res *typeResolver) *resolutionContext { 292 ctx, cancel := context.WithCancel(context.Background()) 293 return &resolutionContext{ 294 ctx: ctx, 295 cancel: cancel, 296 res: res, 297 typeLocations: map[string]string{}, 298 files: map[string]*fileEntry{}, 299 } 300 } 301 302 // addType adds the type at the given URL to the context, using the given fetcher to download the type's 303 // description. This function will recursively add dependencies (e.g. types referenced by the given type's 304 // fields if it is a message type), fetching their type descriptions concurrently. 305 func (rc *resolutionContext) addType(url string, enum bool) error { 306 if err := rc.ctx.Err(); err != nil { 307 return err 308 } 309 310 m, err := rc.res.fetcher(url, enum) 311 if err != nil { 312 return err 313 } else if m == nil { 314 return fmt.Errorf("failed to locate type for %s", url) 315 } 316 317 if enum { 318 rc.recordEnum(url, m.(*typepb.Enum)) 319 return nil 320 } 321 322 // for messages, resolve dependencies in parallel 323 t := m.(*typepb.Type) 324 fe, fileName := rc.recordType(url, t) 325 if fe == nil { 326 // already resolved this one 327 return nil 328 } 329 330 var wg sync.WaitGroup 331 var failed int32 332 for _, f := range t.Fields { 333 if f.Kind == typepb.Field_TYPE_GROUP || f.Kind == typepb.Field_TYPE_MESSAGE || f.Kind == typepb.Field_TYPE_ENUM { 334 typeUrl := ensureScheme(f.TypeUrl) 335 kind := f.Kind 336 wg.Add(1) 337 go func() { 338 defer wg.Done() 339 // first check the registry for descriptors 340 var d desc.Descriptor 341 var innerErr error 342 if kind == typepb.Field_TYPE_ENUM { 343 var ed *desc.EnumDescriptor 344 ed, innerErr = rc.res.mr.getRegisteredEnumTypeByUrl(typeUrl) 345 if ed != nil { 346 d = ed 347 } 348 } else { 349 var md *desc.MessageDescriptor 350 md, innerErr = rc.res.mr.getRegisteredMessageTypeByUrl(typeUrl) 351 if md != nil { 352 d = md 353 } 354 } 355 356 if innerErr == nil { 357 if d != nil { 358 // found it! 359 rc.recordDescriptor(typeUrl, fileName, d) 360 } else { 361 // not in registry, so we have to recursively fetch 362 innerErr = rc.addType(typeUrl, kind == typepb.Field_TYPE_ENUM) 363 } 364 } 365 366 // We want the "real" error to ultimately propagate to root, not 367 // one of the resulting cancellations (from any concurrent goroutines 368 // working in the same resolution context). 369 if innerErr != nil && (rc.ctx.Err() == nil || innerErr != context.Canceled) { 370 if atomic.CompareAndSwapInt32(&failed, 0, 1) { 371 err = innerErr 372 } 373 rc.cancel() 374 } 375 }() 376 } 377 } 378 wg.Wait() 379 if err != nil { 380 return err 381 } 382 // double-check if context has been cancelled 383 if err = rc.ctx.Err(); err != nil { 384 return err 385 } 386 387 rc.mu.Lock() 388 defer rc.mu.Unlock() 389 390 for _, f := range t.Fields { 391 if f.Kind == typepb.Field_TYPE_GROUP || f.Kind == typepb.Field_TYPE_MESSAGE || f.Kind == typepb.Field_TYPE_ENUM { 392 typeUrl := ensureScheme(f.TypeUrl) 393 if fe.deps == nil { 394 fe.deps = map[string]struct{}{} 395 } 396 dep := rc.typeLocations[typeUrl] 397 if dep != fileName { 398 fe.deps[dep] = struct{}{} 399 } 400 } 401 } 402 return nil 403 } 404 405 func (rc *resolutionContext) recordEnum(url string, e *typepb.Enum) { 406 rc.mu.Lock() 407 defer rc.mu.Unlock() 408 409 var fileName string 410 if e.SourceContext != nil && e.SourceContext.FileName != "" { 411 fileName = e.SourceContext.FileName 412 } else { 413 fileName = fmt.Sprintf("--unknown--%d.proto", rc.unknownCount) 414 rc.unknownCount++ 415 } 416 rc.typeLocations[url] = fileName 417 418 fe := rc.files[fileName] 419 if fe == nil { 420 fe = &fileEntry{} 421 rc.files[fileName] = fe 422 } 423 fe.types.addType(e.Name, e) 424 if e.Syntax == typepb.Syntax_SYNTAX_PROTO3 { 425 fe.proto3 = true 426 } 427 } 428 429 func (rc *resolutionContext) recordType(url string, t *typepb.Type) (*fileEntry, string) { 430 rc.mu.Lock() 431 defer rc.mu.Unlock() 432 433 if _, ok := rc.typeLocations[url]; ok { 434 return nil, "" 435 } 436 437 var fileName string 438 if t.SourceContext != nil && t.SourceContext.FileName != "" { 439 fileName = t.SourceContext.FileName 440 } else { 441 fileName = fmt.Sprintf("--unknown--%d.proto", rc.unknownCount) 442 rc.unknownCount++ 443 } 444 rc.typeLocations[url] = fileName 445 446 fe := rc.files[fileName] 447 if fe == nil { 448 fe = &fileEntry{} 449 rc.files[fileName] = fe 450 } 451 fe.types.addType(t.Name, t) 452 if t.Syntax == typepb.Syntax_SYNTAX_PROTO3 { 453 fe.proto3 = true 454 } 455 456 return fe, fileName 457 } 458 459 func (rc *resolutionContext) recordDescriptor(url, ref string, d desc.Descriptor) { 460 rc.mu.Lock() 461 defer rc.mu.Unlock() 462 463 addDescriptors(ref, rc.files, d, nil, func(dsc desc.Descriptor) bool { 464 u := ensureScheme(rc.res.mr.ComputeUrl(dsc)) 465 if _, ok := rc.typeLocations[u]; ok { 466 // already seen this one 467 return false 468 } 469 fileName := dsc.GetFile().GetName() 470 rc.typeLocations[u] = fileName 471 if dsc == d { 472 // make sure we're also adding the actual URL reference used 473 rc.typeLocations[url] = fileName 474 } 475 return true 476 }) 477 } 478 479 // toFileDescriptors converts the information in the context into a map of file names to file descriptors. 480 func (rc *resolutionContext) toFileDescriptors(mr *MessageRegistry) (map[string]*desc.FileDescriptor, error) { 481 return toFileDescriptors(rc.files, func(tt *typeTrie, name string) (proto.Message, error) { 482 mdp, edp := tt.ptypeToDescriptor(name, mr) 483 if mdp != nil { 484 return mdp, nil 485 } else { 486 return edp, nil 487 } 488 }) 489 } 490 491 // converts a map of file entries into a map of file descriptors using the given function to convert 492 // each trie node into a descriptor proto. 493 func toFileDescriptors(files map[string]*fileEntry, trieFn func(*typeTrie, string) (proto.Message, error)) (map[string]*desc.FileDescriptor, error) { 494 fdps := map[string]*descriptorpb.FileDescriptorProto{} 495 for name, file := range files { 496 fdp, err := file.toFileDescriptor(name, trieFn) 497 if err != nil { 498 return nil, err 499 } 500 fdps[name] = fdp 501 } 502 fds := map[string]*desc.FileDescriptor{} 503 for name, fdp := range fdps { 504 if _, ok := fds[name]; ok { 505 continue 506 } 507 var err error 508 if fds[name], err = makeFileDesc(fdp, fds, fdps); err != nil { 509 return nil, err 510 } 511 } 512 return fds, nil 513 } 514 515 func makeFileDesc(fdp *descriptorpb.FileDescriptorProto, fds map[string]*desc.FileDescriptor, fdps map[string]*descriptorpb.FileDescriptorProto) (*desc.FileDescriptor, error) { 516 deps := make([]*desc.FileDescriptor, len(fdp.Dependency)) 517 for i, dep := range fdp.Dependency { 518 d := fds[dep] 519 if d == nil { 520 var err error 521 depFd := fdps[dep] 522 if depFd == nil { 523 return nil, fmt.Errorf("missing dependency: %s", dep) 524 } 525 d, err = makeFileDesc(depFd, fds, fdps) 526 if err != nil { 527 return nil, err 528 } 529 } 530 deps[i] = d 531 } 532 if fd, err := desc.CreateFileDescriptor(fdp, deps...); err != nil { 533 return nil, err 534 } else { 535 fds[fdp.GetName()] = fd 536 return fd, nil 537 } 538 } 539 540 // fileEntry represents the contents of a single file. 541 type fileEntry struct { 542 types typeTrie 543 deps map[string]struct{} 544 proto3 bool 545 } 546 547 // toFileDescriptor converts this file entry into a file descriptor proto. The given function 548 // is used to transform nodes in a typeTrie into message and/or enum descriptor protos. 549 func (fe *fileEntry) toFileDescriptor(name string, trieFn func(*typeTrie, string) (proto.Message, error)) (*descriptorpb.FileDescriptorProto, error) { 550 var pkg bytes.Buffer 551 tt := &fe.types 552 first := true 553 last := "" 554 for tt.typ == nil { 555 if last != "" { 556 if first { 557 first = false 558 } else { 559 pkg.WriteByte('.') 560 } 561 pkg.WriteString(last) 562 } 563 if len(tt.children) != 1 { 564 break 565 } 566 for last, tt = range tt.children { 567 } 568 } 569 fd := createFileDescriptor(name, pkg.String(), fe.proto3, fe.deps) 570 if tt.typ != nil { 571 pm, err := trieFn(tt, last) 572 if err != nil { 573 return nil, err 574 } 575 if mdp, ok := pm.(*descriptorpb.DescriptorProto); ok { 576 fd.MessageType = append(fd.MessageType, mdp) 577 } else if edp, ok := pm.(*descriptorpb.EnumDescriptorProto); ok { 578 fd.EnumType = append(fd.EnumType, edp) 579 } else { 580 sdp := pm.(*descriptorpb.ServiceDescriptorProto) 581 fd.Service = append(fd.Service, sdp) 582 } 583 } else { 584 for name, nested := range tt.children { 585 pm, err := trieFn(nested, name) 586 if err != nil { 587 return nil, err 588 } 589 if mdp, ok := pm.(*descriptorpb.DescriptorProto); ok { 590 fd.MessageType = append(fd.MessageType, mdp) 591 } else if edp, ok := pm.(*descriptorpb.EnumDescriptorProto); ok { 592 fd.EnumType = append(fd.EnumType, edp) 593 } else { 594 sdp := pm.(*descriptorpb.ServiceDescriptorProto) 595 fd.Service = append(fd.Service, sdp) 596 } 597 } 598 } 599 return fd, nil 600 } 601 602 // typeTrie is a prefix trie where each key component is part of a fully-qualified type name. So key components 603 // will either be package name components or element names. 604 type typeTrie struct { 605 // successor key components 606 children map[string]*typeTrie 607 // if non-nil, the element whose fully-qualified name is the path from the trie root to this node 608 typ proto.Message 609 } 610 611 // addType recursively adds an element to the trie. 612 func (t *typeTrie) addType(key string, typ proto.Message) { 613 if key == "" { 614 t.typ = typ 615 return 616 } 617 if t.children == nil { 618 t.children = map[string]*typeTrie{} 619 } 620 curr, rest := split(key) 621 child := t.children[curr] 622 if child == nil { 623 child = &typeTrie{} 624 t.children[curr] = child 625 } 626 child.addType(rest, typ) 627 } 628 629 // ptypeToDescriptor converts this level of the trie into a message or enum 630 // descriptor proto, requiring that the element stored in t.typ is a *ptype.Type 631 // or *ptype.Enum. If t.typ is nil, a placeholder message (with no fields) is 632 // returned that contains the trie's children as nested message and/or enum 633 // types. 634 // 635 // If the value in t.typ is already a *descriptor.DescriptorProto or a 636 // *descriptor.EnumDescriptorProto then it is returned as is. This function 637 // should not be used in type tries that may have service descriptors. That will 638 // result in a panic. 639 func (t *typeTrie) ptypeToDescriptor(name string, mr *MessageRegistry) (*descriptorpb.DescriptorProto, *descriptorpb.EnumDescriptorProto) { 640 switch typ := t.typ.(type) { 641 case *descriptorpb.EnumDescriptorProto: 642 return nil, typ 643 case *typepb.Enum: 644 return nil, createEnumDescriptor(typ, mr) 645 case *descriptorpb.DescriptorProto: 646 return typ, nil 647 default: 648 var msg *descriptorpb.DescriptorProto 649 if t.typ == nil { 650 msg = createIntermediateMessageDescriptor(name) 651 } else { 652 msg = createMessageDescriptor(t.typ.(*typepb.Type), mr) 653 } 654 // sort children for deterministic output 655 var keys []string 656 for k := range t.children { 657 keys = append(keys, k) 658 } 659 for _, name := range keys { 660 nested := t.children[name] 661 chMsg, chEnum := nested.ptypeToDescriptor(name, mr) 662 if chMsg != nil { 663 msg.NestedType = append(msg.NestedType, chMsg) 664 } 665 if chEnum != nil { 666 msg.EnumType = append(msg.EnumType, chEnum) 667 } 668 } 669 return msg, nil 670 } 671 } 672 673 // rewriteDescriptor converts this level of the trie into a new descriptor 674 // proto, requiring that the element stored in t.type is already a service, 675 // message, or enum descriptor proto. If this trie has children then t.typ must 676 // be a message descriptor proto. The returned descriptor proto is the same as 677 // .type but with possibly new nested elements to represent this trie node's 678 // children. 679 func (t *typeTrie) rewriteDescriptor(name string) (proto.Message, error) { 680 if len(t.children) == 0 && t.typ != nil { 681 if mdp, ok := t.typ.(*descriptorpb.DescriptorProto); ok { 682 if len(mdp.NestedType) == 0 && len(mdp.EnumType) == 0 { 683 return mdp, nil 684 } 685 mdp = proto.Clone(mdp).(*descriptorpb.DescriptorProto) 686 mdp.NestedType = nil 687 mdp.EnumType = nil 688 return mdp, nil 689 } 690 return t.typ, nil 691 } 692 var mdp *descriptorpb.DescriptorProto 693 if t.typ == nil { 694 mdp = createIntermediateMessageDescriptor(name) 695 } else { 696 mdp = t.typ.(*descriptorpb.DescriptorProto) 697 mdp = proto.Clone(mdp).(*descriptorpb.DescriptorProto) 698 mdp.NestedType = nil 699 mdp.EnumType = nil 700 } 701 // sort children for deterministic output 702 var keys []string 703 for k := range t.children { 704 keys = append(keys, k) 705 } 706 for _, n := range keys { 707 ch := t.children[n] 708 typ, err := ch.rewriteDescriptor(n) 709 if err != nil { 710 return nil, err 711 } 712 switch typ := typ.(type) { 713 case (*descriptorpb.DescriptorProto): 714 mdp.NestedType = append(mdp.NestedType, typ) 715 case (*descriptorpb.EnumDescriptorProto): 716 mdp.EnumType = append(mdp.EnumType, typ) 717 default: 718 // TODO: this should probably panic instead 719 return nil, fmt.Errorf("invalid descriptor trie: message cannot have child of type %v", reflect.TypeOf(typ)) 720 } 721 } 722 return mdp, nil 723 } 724 725 func split(s string) (string, string) { 726 pos := strings.Index(s, ".") 727 if pos >= 0 { 728 return s[:pos], s[pos+1:] 729 } else { 730 return s, "" 731 } 732 } 733 734 func createEnumDescriptor(e *typepb.Enum, mr *MessageRegistry) *descriptorpb.EnumDescriptorProto { 735 var opts *descriptorpb.EnumOptions 736 if len(e.Options) > 0 { 737 dopts := createOptions(e.Options, enumOptionsDesc, mr) 738 opts = &descriptorpb.EnumOptions{} 739 dopts.ConvertTo(opts) // ignore any error 740 } 741 742 var vals []*descriptorpb.EnumValueDescriptorProto 743 for _, v := range e.Enumvalue { 744 evd := createEnumValueDescriptor(v, mr) 745 vals = append(vals, evd) 746 } 747 748 return &descriptorpb.EnumDescriptorProto{ 749 Name: proto.String(base(e.Name)), 750 Options: opts, 751 Value: vals, 752 } 753 } 754 755 func createEnumValueDescriptor(v *typepb.EnumValue, mr *MessageRegistry) *descriptorpb.EnumValueDescriptorProto { 756 var opts *descriptorpb.EnumValueOptions 757 if len(v.Options) > 0 { 758 dopts := createOptions(v.Options, enumValueOptionsDesc, mr) 759 opts = &descriptorpb.EnumValueOptions{} 760 dopts.ConvertTo(opts) // ignore any error 761 } 762 763 return &descriptorpb.EnumValueDescriptorProto{ 764 Name: proto.String(v.Name), 765 Number: proto.Int32(v.Number), 766 Options: opts, 767 } 768 } 769 770 func createMessageDescriptor(m *typepb.Type, mr *MessageRegistry) *descriptorpb.DescriptorProto { 771 var opts *descriptorpb.MessageOptions 772 if len(m.Options) > 0 { 773 dopts := createOptions(m.Options, msgOptionsDesc, mr) 774 opts = &descriptorpb.MessageOptions{} 775 dopts.ConvertTo(opts) // ignore any error 776 } 777 778 var fields []*descriptorpb.FieldDescriptorProto 779 for _, f := range m.Fields { 780 fields = append(fields, createFieldDescriptor(f, mr)) 781 } 782 783 var oneOfs []*descriptorpb.OneofDescriptorProto 784 for _, o := range m.Oneofs { 785 oneOfs = append(oneOfs, &descriptorpb.OneofDescriptorProto{ 786 Name: proto.String(o), 787 }) 788 } 789 790 return &descriptorpb.DescriptorProto{ 791 Name: proto.String(base(m.Name)), 792 Options: opts, 793 Field: fields, 794 OneofDecl: oneOfs, 795 } 796 } 797 798 func createFieldDescriptor(f *typepb.Field, mr *MessageRegistry) *descriptorpb.FieldDescriptorProto { 799 var opts *descriptorpb.FieldOptions 800 if len(f.Options) > 0 { 801 dopts := createOptions(f.Options, fieldOptionsDesc, mr) 802 opts = &descriptorpb.FieldOptions{} 803 dopts.ConvertTo(opts) // ignore any error 804 } 805 if f.Packed { 806 if opts == nil { 807 opts = &descriptorpb.FieldOptions{Packed: proto.Bool(true)} 808 } else { 809 opts.Packed = proto.Bool(true) 810 } 811 } 812 813 var oneOf *int32 814 if f.OneofIndex > 0 { 815 oneOf = proto.Int32(f.OneofIndex - 1) 816 } 817 818 var typeName string 819 if f.Kind == typepb.Field_TYPE_GROUP || f.Kind == typepb.Field_TYPE_MESSAGE || f.Kind == typepb.Field_TYPE_ENUM { 820 pos := strings.LastIndex(f.TypeUrl, "/") 821 typeName = "." + f.TypeUrl[pos+1:] 822 } 823 824 var label descriptorpb.FieldDescriptorProto_Label 825 switch f.Cardinality { 826 case typepb.Field_CARDINALITY_OPTIONAL: 827 label = descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL 828 case typepb.Field_CARDINALITY_REPEATED: 829 label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED 830 case typepb.Field_CARDINALITY_REQUIRED: 831 label = descriptorpb.FieldDescriptorProto_LABEL_REQUIRED 832 } 833 834 var typ descriptorpb.FieldDescriptorProto_Type 835 switch f.Kind { 836 case typepb.Field_TYPE_ENUM: 837 typ = descriptorpb.FieldDescriptorProto_TYPE_ENUM 838 case typepb.Field_TYPE_GROUP: 839 typ = descriptorpb.FieldDescriptorProto_TYPE_GROUP 840 case typepb.Field_TYPE_MESSAGE: 841 typ = descriptorpb.FieldDescriptorProto_TYPE_MESSAGE 842 case typepb.Field_TYPE_BYTES: 843 typ = descriptorpb.FieldDescriptorProto_TYPE_BYTES 844 case typepb.Field_TYPE_STRING: 845 typ = descriptorpb.FieldDescriptorProto_TYPE_STRING 846 case typepb.Field_TYPE_BOOL: 847 typ = descriptorpb.FieldDescriptorProto_TYPE_BOOL 848 case typepb.Field_TYPE_DOUBLE: 849 typ = descriptorpb.FieldDescriptorProto_TYPE_DOUBLE 850 case typepb.Field_TYPE_FLOAT: 851 typ = descriptorpb.FieldDescriptorProto_TYPE_FLOAT 852 case typepb.Field_TYPE_FIXED32: 853 typ = descriptorpb.FieldDescriptorProto_TYPE_FIXED32 854 case typepb.Field_TYPE_FIXED64: 855 typ = descriptorpb.FieldDescriptorProto_TYPE_FIXED64 856 case typepb.Field_TYPE_INT32: 857 typ = descriptorpb.FieldDescriptorProto_TYPE_INT32 858 case typepb.Field_TYPE_INT64: 859 typ = descriptorpb.FieldDescriptorProto_TYPE_INT64 860 case typepb.Field_TYPE_SFIXED32: 861 typ = descriptorpb.FieldDescriptorProto_TYPE_SFIXED32 862 case typepb.Field_TYPE_SFIXED64: 863 typ = descriptorpb.FieldDescriptorProto_TYPE_SFIXED64 864 case typepb.Field_TYPE_SINT32: 865 typ = descriptorpb.FieldDescriptorProto_TYPE_SINT32 866 case typepb.Field_TYPE_SINT64: 867 typ = descriptorpb.FieldDescriptorProto_TYPE_SINT64 868 case typepb.Field_TYPE_UINT32: 869 typ = descriptorpb.FieldDescriptorProto_TYPE_UINT32 870 case typepb.Field_TYPE_UINT64: 871 typ = descriptorpb.FieldDescriptorProto_TYPE_UINT64 872 } 873 var defaultVal *string 874 if f.DefaultValue != "" { 875 defaultVal = proto.String(f.DefaultValue) 876 } 877 return &descriptorpb.FieldDescriptorProto{ 878 Name: proto.String(f.Name), 879 Number: proto.Int32(f.Number), 880 DefaultValue: defaultVal, 881 JsonName: proto.String(f.JsonName), 882 OneofIndex: oneOf, 883 TypeName: proto.String(typeName), 884 Label: label.Enum(), 885 Type: typ.Enum(), 886 Options: opts, 887 } 888 } 889 890 func createServiceDescriptor(a *apipb.Api, mr *MessageRegistry) *descriptorpb.ServiceDescriptorProto { 891 var opts *descriptorpb.ServiceOptions 892 if len(a.Options) > 0 { 893 dopts := createOptions(a.Options, svcOptionsDesc, mr) 894 opts = &descriptorpb.ServiceOptions{} 895 dopts.ConvertTo(opts) // ignore any error 896 } 897 898 methods := make([]*descriptorpb.MethodDescriptorProto, len(a.Methods)) 899 for i, m := range a.Methods { 900 methods[i] = createMethodDescriptor(m, mr) 901 } 902 903 return &descriptorpb.ServiceDescriptorProto{ 904 Name: proto.String(base(a.Name)), 905 Method: methods, 906 Options: opts, 907 } 908 } 909 910 func createMethodDescriptor(m *apipb.Method, mr *MessageRegistry) *descriptorpb.MethodDescriptorProto { 911 var opts *descriptorpb.MethodOptions 912 if len(m.Options) > 0 { 913 dopts := createOptions(m.Options, methodOptionsDesc, mr) 914 opts = &descriptorpb.MethodOptions{} 915 dopts.ConvertTo(opts) // ignore any error 916 } 917 918 var reqType, respType string 919 pos := strings.LastIndex(m.RequestTypeUrl, "/") 920 reqType = "." + m.RequestTypeUrl[pos+1:] 921 pos = strings.LastIndex(m.ResponseTypeUrl, "/") 922 respType = "." + m.ResponseTypeUrl[pos+1:] 923 924 return &descriptorpb.MethodDescriptorProto{ 925 Name: proto.String(m.Name), 926 Options: opts, 927 ClientStreaming: proto.Bool(m.RequestStreaming), 928 ServerStreaming: proto.Bool(m.ResponseStreaming), 929 InputType: proto.String(reqType), 930 OutputType: proto.String(respType), 931 } 932 } 933 934 func createIntermediateMessageDescriptor(name string) *descriptorpb.DescriptorProto { 935 return &descriptorpb.DescriptorProto{ 936 Name: proto.String(name), 937 } 938 } 939 940 func createFileDescriptor(name, pkg string, proto3 bool, deps map[string]struct{}) *descriptorpb.FileDescriptorProto { 941 imports := make([]string, 0, len(deps)) 942 for k := range deps { 943 imports = append(imports, k) 944 } 945 sort.Strings(imports) 946 var syntax string 947 if proto3 { 948 syntax = "proto3" 949 } else { 950 syntax = "proto2" 951 } 952 return &descriptorpb.FileDescriptorProto{ 953 Name: proto.String(name), 954 Package: proto.String(pkg), 955 Syntax: proto.String(syntax), 956 Dependency: imports, 957 } 958 } 959 960 func createOptions(options []*typepb.Option, optionsDesc *desc.MessageDescriptor, mr *MessageRegistry) *dynamic.Message { 961 // these are created "best effort" so entries which are unresolvable 962 // (or seemingly invalid) are simply ignored... 963 dopts := mr.mf.NewDynamicMessage(optionsDesc) 964 for _, o := range options { 965 field := optionsDesc.FindFieldByName(o.Name) 966 if field == nil { 967 field = mr.er.FindExtensionByName(optionsDesc.GetFullyQualifiedName(), o.Name) 968 if field == nil && o.Name[0] != '[' { 969 field = mr.er.FindExtensionByName(optionsDesc.GetFullyQualifiedName(), fmt.Sprintf("[%s]", o.Name)) 970 } 971 if field == nil { 972 // can't resolve option name? skip it 973 continue 974 } 975 } 976 v, err := mr.unmarshalAny(o.Value, func(url string) (*desc.MessageDescriptor, error) { 977 // we don't want to try to recursively fetch this value's type, so if it doesn't 978 // match the type of the extension field, we'll skip it 979 if (field.GetType() == descriptorpb.FieldDescriptorProto_TYPE_GROUP || 980 field.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) && 981 typeName(url) == field.GetMessageType().GetFullyQualifiedName() { 982 983 return field.GetMessageType(), nil 984 } 985 return nil, nil 986 }) 987 if err != nil { 988 // can't interpret value? skip it 989 continue 990 } 991 var fv interface{} 992 if field.GetType() != descriptorpb.FieldDescriptorProto_TYPE_MESSAGE && field.GetType() != descriptorpb.FieldDescriptorProto_TYPE_GROUP { 993 fv = unwrap(v) 994 if v == nil { 995 // non-wrapper type for scalar field? skip it 996 continue 997 } 998 } else { 999 fv = v 1000 } 1001 if field.IsRepeated() { 1002 dopts.TryAddRepeatedField(field, fv) // ignore any error 1003 } else { 1004 dopts.TrySetField(field, fv) // ignore any error 1005 } 1006 } 1007 return dopts 1008 } 1009 1010 func base(name string) string { 1011 pos := strings.LastIndex(name, ".") 1012 if pos >= 0 { 1013 return name[pos+1:] 1014 } 1015 return name 1016 } 1017 1018 func unwrap(msg proto.Message) interface{} { 1019 switch m := msg.(type) { 1020 case (*wrapperspb.BoolValue): 1021 return m.Value 1022 case (*wrapperspb.FloatValue): 1023 return m.Value 1024 case (*wrapperspb.DoubleValue): 1025 return m.Value 1026 case (*wrapperspb.Int32Value): 1027 return m.Value 1028 case (*wrapperspb.Int64Value): 1029 return m.Value 1030 case (*wrapperspb.UInt32Value): 1031 return m.Value 1032 case (*wrapperspb.UInt64Value): 1033 return m.Value 1034 case (*wrapperspb.BytesValue): 1035 return m.Value 1036 case (*wrapperspb.StringValue): 1037 return m.Value 1038 default: 1039 return nil 1040 } 1041 } 1042 1043 func typeName(url string) string { 1044 return url[strings.LastIndex(url, "/")+1:] 1045 }