github.com/Big-big-orange/protoreflect@v0.0.0-20240408141420-285cedfdf6a4/desc/builder/resolver.go (about)

     1  package builder
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strings"
     7  
     8  	"google.golang.org/protobuf/encoding/protowire"
     9  	"google.golang.org/protobuf/proto"
    10  	"google.golang.org/protobuf/reflect/protoreflect"
    11  	"google.golang.org/protobuf/reflect/protoregistry"
    12  	"google.golang.org/protobuf/types/descriptorpb"
    13  
    14  	"github.com/Big-big-orange/protoreflect/desc"
    15  	"github.com/Big-big-orange/protoreflect/desc/internal"
    16  	"github.com/Big-big-orange/protoreflect/dynamic"
    17  )
    18  
    19  type dependencies struct {
    20  	descs map[*desc.FileDescriptor]struct{}
    21  	res   protoregistry.Types
    22  }
    23  
    24  func newDependencies() *dependencies {
    25  	return &dependencies{
    26  		descs: map[*desc.FileDescriptor]struct{}{},
    27  	}
    28  }
    29  
    30  func (d *dependencies) add(fd *desc.FileDescriptor) {
    31  	if _, ok := d.descs[fd]; ok {
    32  		// already added
    33  		return
    34  	}
    35  	d.descs[fd] = struct{}{}
    36  	internal.RegisterExtensionsFromImportedFile(&d.res, fd.UnwrapFile())
    37  }
    38  
    39  // dependencyResolver is the work-horse for converting a tree of builders into a
    40  // tree of descriptors. It scans a root (usually a file builder) and recursively
    41  // resolves all dependencies (references to builders in other trees as well as
    42  // references to other already-built descriptors). The result of resolution is a
    43  // file descriptor (or an error).
    44  type dependencyResolver struct {
    45  	resolvedRoots map[Builder]*desc.FileDescriptor
    46  	seen          map[Builder]struct{}
    47  	opts          BuilderOptions
    48  }
    49  
    50  func newResolver(opts BuilderOptions) *dependencyResolver {
    51  	return &dependencyResolver{
    52  		resolvedRoots: map[Builder]*desc.FileDescriptor{},
    53  		seen:          map[Builder]struct{}{},
    54  		opts:          opts,
    55  	}
    56  }
    57  
    58  func (r *dependencyResolver) resolveElement(b Builder, seen []Builder) (*desc.FileDescriptor, error) {
    59  	b = getRoot(b)
    60  
    61  	if fd, ok := r.resolvedRoots[b]; ok {
    62  		return fd, nil
    63  	}
    64  
    65  	for _, s := range seen {
    66  		if s == b {
    67  			names := make([]string, len(seen)+1)
    68  			for i, s := range seen {
    69  				names[i] = s.GetName()
    70  			}
    71  			names[len(seen)] = b.GetName()
    72  			return nil, fmt.Errorf("descriptors have cyclic dependency: %s", strings.Join(names, " ->  "))
    73  		}
    74  	}
    75  	seen = append(seen, b)
    76  
    77  	var fd *desc.FileDescriptor
    78  	var err error
    79  	switch b := b.(type) {
    80  	case *FileBuilder:
    81  		fd, err = r.resolveFile(b, b, seen)
    82  	default:
    83  		fd, err = r.resolveSyntheticFile(b, seen)
    84  	}
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	r.resolvedRoots[b] = fd
    89  	return fd, nil
    90  }
    91  
    92  func (r *dependencyResolver) resolveFile(fb *FileBuilder, root Builder, seen []Builder) (*desc.FileDescriptor, error) {
    93  	deps := newDependencies()
    94  	// add explicit imports first
    95  	for fd := range fb.explicitImports {
    96  		deps.add(fd)
    97  	}
    98  	for dep := range fb.explicitDeps {
    99  		if dep == fb {
   100  			// ignore erroneous self references
   101  			continue
   102  		}
   103  		fd, err := r.resolveElement(dep, seen)
   104  		if err != nil {
   105  			return nil, err
   106  		}
   107  		deps.add(fd)
   108  	}
   109  	// now accumulate implicit dependencies based on other types referenced
   110  	for _, mb := range fb.messages {
   111  		if err := r.resolveTypesInMessage(root, seen, deps, mb); err != nil {
   112  			return nil, err
   113  		}
   114  	}
   115  	for _, exb := range fb.extensions {
   116  		if err := r.resolveTypesInExtension(root, seen, deps, exb); err != nil {
   117  			return nil, err
   118  		}
   119  	}
   120  	for _, sb := range fb.services {
   121  		if err := r.resolveTypesInService(root, seen, deps, sb); err != nil {
   122  			return nil, err
   123  		}
   124  	}
   125  
   126  	// finally, resolve custom options (which may refer to deps already
   127  	// computed above)
   128  	if err := r.resolveTypesInFileOptions(root, deps, fb); err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	depSlice := make([]*desc.FileDescriptor, 0, len(deps.descs))
   133  	depMap := make(map[string]*desc.FileDescriptor, len(deps.descs))
   134  	for dep := range deps.descs {
   135  		isDuplicate, err := isDuplicateDependency(dep, depMap)
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  		if !isDuplicate {
   140  			depSlice = append(depSlice, dep)
   141  		}
   142  	}
   143  
   144  	fp, err := fb.buildProto(depSlice)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	// make sure this file name doesn't collide with any of its dependencies
   150  	fileNames := map[string]struct{}{}
   151  	for _, d := range depSlice {
   152  		addFileNames(d, fileNames)
   153  	}
   154  	unique := makeUnique(fp.GetName(), fileNames)
   155  	if unique != fp.GetName() {
   156  		fp.Name = proto.String(unique)
   157  	}
   158  
   159  	return desc.CreateFileDescriptor(fp, depSlice...)
   160  }
   161  
   162  // isDuplicateDependency checks for duplicate descriptors
   163  func isDuplicateDependency(dep *desc.FileDescriptor, depMap map[string]*desc.FileDescriptor) (bool, error) {
   164  	if _, exists := depMap[dep.GetName()]; !exists {
   165  		depMap[dep.GetName()] = dep
   166  		return false, nil
   167  	}
   168  	prevFDP := depMap[dep.GetName()].AsFileDescriptorProto()
   169  	depFDP := dep.AsFileDescriptorProto()
   170  
   171  	// temporarily reset Source Code Fields as builders does not have SourceCodeInfo
   172  	defer setSourceCodeInfo(prevFDP, nil)()
   173  	defer setSourceCodeInfo(depFDP, nil)()
   174  
   175  	if !proto.Equal(prevFDP, depFDP) {
   176  		return true, fmt.Errorf("multiple versions of descriptors found with same file name: %s", dep.GetName())
   177  	}
   178  	return true, nil
   179  }
   180  
   181  func setSourceCodeInfo(fdp *descriptorpb.FileDescriptorProto, info *descriptorpb.SourceCodeInfo) (reset func()) {
   182  	prevSourceCodeInfo := fdp.SourceCodeInfo
   183  	fdp.SourceCodeInfo = info
   184  	return func() { fdp.SourceCodeInfo = prevSourceCodeInfo }
   185  }
   186  
   187  func addFileNames(fd *desc.FileDescriptor, files map[string]struct{}) {
   188  	if _, ok := files[fd.GetName()]; ok {
   189  		// already added
   190  		return
   191  	}
   192  	files[fd.GetName()] = struct{}{}
   193  	for _, d := range fd.GetDependencies() {
   194  		addFileNames(d, files)
   195  	}
   196  }
   197  
   198  func (r *dependencyResolver) resolveSyntheticFile(b Builder, seen []Builder) (*desc.FileDescriptor, error) {
   199  	// find ancestor to temporarily attach to new file
   200  	curr := b
   201  	for curr.GetParent() != nil {
   202  		curr = curr.GetParent()
   203  	}
   204  	f := NewFile("")
   205  	switch curr := curr.(type) {
   206  	case *MessageBuilder:
   207  		f.messages = append(f.messages, curr)
   208  	case *EnumBuilder:
   209  		f.enums = append(f.enums, curr)
   210  	case *ServiceBuilder:
   211  		f.services = append(f.services, curr)
   212  	case *FieldBuilder:
   213  		if curr.IsExtension() {
   214  			f.extensions = append(f.extensions, curr)
   215  		} else {
   216  			panic("field must be added to message before calling Build()")
   217  		}
   218  	case *OneOfBuilder:
   219  		if _, ok := b.(*OneOfBuilder); ok {
   220  			panic("one-of must be added to message before calling Build()")
   221  		} else {
   222  			// b was a child of one-of which means it must have been a field
   223  			panic("field must be added to message before calling Build()")
   224  		}
   225  	case *MethodBuilder:
   226  		panic("method must be added to service before calling Build()")
   227  	case *EnumValueBuilder:
   228  		panic("enum value must be added to enum before calling Build()")
   229  	default:
   230  		panic(fmt.Sprintf("Unrecognized kind of builder: %T", b))
   231  	}
   232  	curr.setParent(f)
   233  
   234  	// don't forget to reset when done
   235  	defer func() {
   236  		curr.setParent(nil)
   237  	}()
   238  
   239  	return r.resolveFile(f, b, seen)
   240  }
   241  
   242  func (r *dependencyResolver) resolveTypesInMessage(root Builder, seen []Builder, deps *dependencies, mb *MessageBuilder) error {
   243  	for _, b := range mb.fieldsAndOneOfs {
   244  		if flb, ok := b.(*FieldBuilder); ok {
   245  			if err := r.resolveTypesInField(root, seen, deps, flb); err != nil {
   246  				return err
   247  			}
   248  		} else {
   249  			oob := b.(*OneOfBuilder)
   250  			for _, flb := range oob.choices {
   251  				if err := r.resolveTypesInField(root, seen, deps, flb); err != nil {
   252  					return err
   253  				}
   254  			}
   255  		}
   256  	}
   257  	for _, nmb := range mb.nestedMessages {
   258  		if err := r.resolveTypesInMessage(root, seen, deps, nmb); err != nil {
   259  			return err
   260  		}
   261  	}
   262  	for _, exb := range mb.nestedExtensions {
   263  		if err := r.resolveTypesInExtension(root, seen, deps, exb); err != nil {
   264  			return err
   265  		}
   266  	}
   267  	return nil
   268  }
   269  
   270  func (r *dependencyResolver) resolveTypesInExtension(root Builder, seen []Builder, deps *dependencies, exb *FieldBuilder) error {
   271  	if err := r.resolveTypesInField(root, seen, deps, exb); err != nil {
   272  		return err
   273  	}
   274  	if exb.foreignExtendee != nil {
   275  		deps.add(exb.foreignExtendee.GetFile())
   276  	} else if err := r.resolveType(root, seen, exb.localExtendee, deps); err != nil {
   277  		return err
   278  	}
   279  	return nil
   280  }
   281  
   282  func (r *dependencyResolver) resolveTypesInService(root Builder, seen []Builder, deps *dependencies, sb *ServiceBuilder) error {
   283  	for _, mtb := range sb.methods {
   284  		if err := r.resolveRpcType(root, seen, mtb.ReqType, deps); err != nil {
   285  			return err
   286  		}
   287  		if err := r.resolveRpcType(root, seen, mtb.RespType, deps); err != nil {
   288  			return err
   289  		}
   290  	}
   291  	return nil
   292  }
   293  
   294  func (r *dependencyResolver) resolveRpcType(root Builder, seen []Builder, t *RpcType, deps *dependencies) error {
   295  	if t.foreignType != nil {
   296  		deps.add(t.foreignType.GetFile())
   297  	} else {
   298  		return r.resolveType(root, seen, t.localType, deps)
   299  	}
   300  	return nil
   301  }
   302  
   303  func (r *dependencyResolver) resolveTypesInField(root Builder, seen []Builder, deps *dependencies, flb *FieldBuilder) error {
   304  	if flb.fieldType.foreignMsgType != nil {
   305  		deps.add(flb.fieldType.foreignMsgType.GetFile())
   306  	} else if flb.fieldType.foreignEnumType != nil {
   307  		deps.add(flb.fieldType.foreignEnumType.GetFile())
   308  	} else if flb.fieldType.localMsgType != nil {
   309  		if flb.fieldType.localMsgType == flb.msgType {
   310  			return r.resolveTypesInMessage(root, seen, deps, flb.msgType)
   311  		} else {
   312  			return r.resolveType(root, seen, flb.fieldType.localMsgType, deps)
   313  		}
   314  	} else if flb.fieldType.localEnumType != nil {
   315  		return r.resolveType(root, seen, flb.fieldType.localEnumType, deps)
   316  	}
   317  	return nil
   318  }
   319  
   320  func (r *dependencyResolver) resolveType(root Builder, seen []Builder, typeBuilder Builder, deps *dependencies) error {
   321  	otherRoot := getRoot(typeBuilder)
   322  	if root == otherRoot {
   323  		// local reference, so it will get resolved when we finish resolving this root
   324  		return nil
   325  	}
   326  	fd, err := r.resolveElement(otherRoot, seen)
   327  	if err != nil {
   328  		return err
   329  	}
   330  	deps.add(fd)
   331  	return nil
   332  }
   333  
   334  func (r *dependencyResolver) resolveTypesInFileOptions(root Builder, deps *dependencies, fb *FileBuilder) error {
   335  	for _, mb := range fb.messages {
   336  		if err := r.resolveTypesInMessageOptions(root, fb.origExts, deps, mb); err != nil {
   337  			return err
   338  		}
   339  	}
   340  	for _, eb := range fb.enums {
   341  		if err := r.resolveTypesInEnumOptions(root, fb.origExts, deps, eb); err != nil {
   342  			return err
   343  		}
   344  	}
   345  	for _, exb := range fb.extensions {
   346  		if err := r.resolveTypesInOptions(root, fb.origExts, deps, exb.Options); err != nil {
   347  			return err
   348  		}
   349  	}
   350  	for _, sb := range fb.services {
   351  		for _, mtb := range sb.methods {
   352  			if err := r.resolveTypesInOptions(root, fb.origExts, deps, mtb.Options); err != nil {
   353  				return err
   354  			}
   355  		}
   356  		if err := r.resolveTypesInOptions(root, fb.origExts, deps, sb.Options); err != nil {
   357  			return err
   358  		}
   359  	}
   360  	return r.resolveTypesInOptions(root, fb.origExts, deps, fb.Options)
   361  }
   362  
   363  func (r *dependencyResolver) resolveTypesInMessageOptions(root Builder, fileExts *dynamic.ExtensionRegistry, deps *dependencies, mb *MessageBuilder) error {
   364  	for _, b := range mb.fieldsAndOneOfs {
   365  		if flb, ok := b.(*FieldBuilder); ok {
   366  			if err := r.resolveTypesInOptions(root, fileExts, deps, flb.Options); err != nil {
   367  				return err
   368  			}
   369  		} else {
   370  			oob := b.(*OneOfBuilder)
   371  			for _, flb := range oob.choices {
   372  				if err := r.resolveTypesInOptions(root, fileExts, deps, flb.Options); err != nil {
   373  					return err
   374  				}
   375  			}
   376  			if err := r.resolveTypesInOptions(root, fileExts, deps, oob.Options); err != nil {
   377  				return err
   378  			}
   379  		}
   380  	}
   381  	for _, extr := range mb.ExtensionRanges {
   382  		if err := r.resolveTypesInOptions(root, fileExts, deps, extr.Options); err != nil {
   383  			return err
   384  		}
   385  	}
   386  	for _, eb := range mb.nestedEnums {
   387  		if err := r.resolveTypesInEnumOptions(root, fileExts, deps, eb); err != nil {
   388  			return err
   389  		}
   390  	}
   391  	for _, nmb := range mb.nestedMessages {
   392  		if err := r.resolveTypesInMessageOptions(root, fileExts, deps, nmb); err != nil {
   393  			return err
   394  		}
   395  	}
   396  	for _, exb := range mb.nestedExtensions {
   397  		if err := r.resolveTypesInOptions(root, fileExts, deps, exb.Options); err != nil {
   398  			return err
   399  		}
   400  	}
   401  	if err := r.resolveTypesInOptions(root, fileExts, deps, mb.Options); err != nil {
   402  		return err
   403  	}
   404  	return nil
   405  }
   406  
   407  func (r *dependencyResolver) resolveTypesInEnumOptions(root Builder, fileExts *dynamic.ExtensionRegistry, deps *dependencies, eb *EnumBuilder) error {
   408  	for _, evb := range eb.values {
   409  		if err := r.resolveTypesInOptions(root, fileExts, deps, evb.Options); err != nil {
   410  			return err
   411  		}
   412  	}
   413  	if err := r.resolveTypesInOptions(root, fileExts, deps, eb.Options); err != nil {
   414  		return err
   415  	}
   416  	return nil
   417  }
   418  
   419  func (r *dependencyResolver) resolveTypesInOptions(root Builder, fileExts *dynamic.ExtensionRegistry, deps *dependencies, opts proto.Message) error {
   420  	// nothing to see if opts is nil
   421  	if opts == nil {
   422  		return nil
   423  	}
   424  	if rv := reflect.ValueOf(opts); rv.Kind() == reflect.Ptr && rv.IsNil() {
   425  		return nil
   426  	}
   427  
   428  	ref := opts.ProtoReflect()
   429  	tags := map[int32]protoreflect.ExtensionType{}
   430  	proto.RangeExtensions(opts, func(xt protoreflect.ExtensionType, _ interface{}) bool {
   431  		num := int32(xt.TypeDescriptor().Number())
   432  		tags[num] = xt
   433  		return true
   434  	})
   435  
   436  	unk := ref.GetUnknown()
   437  	for len(unk) > 0 {
   438  		v, n := protowire.ConsumeVarint(unk)
   439  		if n < 0 {
   440  			break
   441  		}
   442  		unk = unk[n:]
   443  
   444  		num, t := protowire.DecodeTag(v)
   445  		if _, ok := tags[int32(num)]; !ok {
   446  			tags[int32(num)] = nil
   447  		}
   448  
   449  		switch t {
   450  		case protowire.VarintType:
   451  			_, n = protowire.ConsumeVarint(unk)
   452  		case protowire.Fixed64Type:
   453  			_, n = protowire.ConsumeFixed64(unk)
   454  		case protowire.BytesType:
   455  			_, n = protowire.ConsumeBytes(unk)
   456  		case protowire.StartGroupType:
   457  			_, n = protowire.ConsumeGroup(num, unk)
   458  		case protowire.EndGroupType:
   459  			// invalid encoding
   460  			break
   461  		case protowire.Fixed32Type:
   462  			_, n = protowire.ConsumeFixed32(unk)
   463  		}
   464  		if n < 0 {
   465  			break
   466  		}
   467  		unk = unk[n:]
   468  	}
   469  
   470  	msgName := string(proto.MessageName(opts))
   471  	for tag, xt := range tags {
   472  		// see if known dependencies have this option
   473  		if _, err := deps.res.FindExtensionByNumber(protoreflect.FullName(msgName), protoreflect.FieldNumber(tag)); err == nil {
   474  			// yep! nothing else to do
   475  			continue
   476  		}
   477  		// see if this extension is defined in *this* builder
   478  		if findExtension(root, msgName, tag) {
   479  			// yep!
   480  			continue
   481  		}
   482  		// see if configured extension registry knows about it
   483  		if extd := r.opts.Extensions.FindExtension(msgName, tag); extd != nil {
   484  			// extension registry recognized it!
   485  			deps.add(extd.GetFile())
   486  			continue
   487  		}
   488  		// see if given file extensions knows about it
   489  		if fileExts != nil {
   490  			extd := fileExts.FindExtension(msgName, tag)
   491  			if extd != nil {
   492  				// file extensions recognized it!
   493  				deps.add(extd.GetFile())
   494  				continue
   495  			}
   496  		}
   497  
   498  		if xt != nil {
   499  			// known extension? add its file to builder's deps
   500  			fd, err := desc.WrapFile(xt.TypeDescriptor().ParentFile())
   501  			if err != nil {
   502  				return err
   503  			}
   504  			deps.add(fd)
   505  			continue
   506  		}
   507  
   508  		if r.opts.RequireInterpretedOptions {
   509  			// we require options to be interpreted but are not able to!
   510  			return fmt.Errorf("could not interpret custom option for %s, tag %d", msgName, tag)
   511  		}
   512  	}
   513  	return nil
   514  }
   515  
   516  func findExtension(b Builder, messageName string, extTag int32) bool {
   517  	if fb, ok := b.(*FileBuilder); ok && findExtensionInFile(fb, messageName, extTag) {
   518  		return true
   519  	}
   520  	if mb, ok := b.(*MessageBuilder); ok && findExtensionInMessage(mb, messageName, extTag) {
   521  		return true
   522  	}
   523  	return false
   524  }
   525  
   526  func findExtensionInFile(fb *FileBuilder, messageName string, extTag int32) bool {
   527  	for _, extb := range fb.extensions {
   528  		if extb.GetExtendeeTypeName() == messageName && extb.number == extTag {
   529  			return true
   530  		}
   531  	}
   532  	for _, mb := range fb.messages {
   533  		if findExtensionInMessage(mb, messageName, extTag) {
   534  			return true
   535  		}
   536  	}
   537  	return false
   538  }
   539  
   540  func findExtensionInMessage(mb *MessageBuilder, messageName string, extTag int32) bool {
   541  	for _, extb := range mb.nestedExtensions {
   542  		if extb.GetExtendeeTypeName() == messageName && extb.number == extTag {
   543  			return true
   544  		}
   545  	}
   546  	for _, mb := range mb.nestedMessages {
   547  		if findExtensionInMessage(mb, messageName, extTag) {
   548  			return true
   549  		}
   550  	}
   551  	return false
   552  }